diff options
Diffstat (limited to 'pkg/tcpip')
111 files changed, 10004 insertions, 3212 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 81f762e10..91971b687 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "reflect" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker { v = ip.TTL() case header.IPv6: v = ip.HopLimit() + case *ipv6HeaderWithExtHdr: + v = ip.HopLimit() + default: + t.Fatalf("unrecognized header type %T for TTL evaluation", ip) } if v != ttl { t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) @@ -216,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker { } } +// IPv4RouterAlert returns a checker that checks that the RouterAlert option is +// set in an IPv4 packet. +func IPv4RouterAlert() NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + iterator := ip.Options().MakeIterator() + for { + opt, done, err := iterator.Next() + if err != nil { + t.Fatalf("error acquiring next IPv4 option %s", err) + } + if done { + break + } + if opt.Type() != header.IPv4OptionRouterAlertType { + continue + } + want := [header.IPv4OptionRouterAlertLength]byte{ + byte(header.IPv4OptionRouterAlertType), + header.IPv4OptionRouterAlertLength, + header.IPv4OptionRouterAlertValue, + header.IPv4OptionRouterAlertValue, + } + if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { + t.Errorf("router alert option mismatch (-want +got):\n%s", diff) + } + return + } + t.Errorf("failed to find router alert option in %v", ip.Options()) + } +} + // FragmentOffset creates a checker that checks the FragmentOffset field. func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress +// field in ControlMessages. +func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasOriginalDstAddress { + t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) + } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { + t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -1012,6 +1066,74 @@ func ICMPv6Payload(want []byte) TransportChecker { } } +// MLD creates a checker that checks that the packet contains a valid MLD +// message for type of mldType, with potentially additional checks specified by +// checkers. +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// MLD message as far as the size of the message (minSize) is concerned. The +// values within the message are up to checkers to validate. +func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // Check normal ICMPv6 first. + ICMPv6( + ICMPv6Type(msgType), + ICMPv6Code(0))(t, h) + + last := h[len(h)-1] + + icmp := header.ICMPv6(last.Payload()) + if got := len(icmp.MessageBody()); got < minSize { + t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) + } + + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMaxRespDelay(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MaximumResponseDelay(); got != want { + t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) + } + } +} + +// MLDMulticastAddress creates a checker that checks the Multicast Address +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMulticastAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MulticastAddress(); got != want { + t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. @@ -1031,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N last := h[len(h)-1] icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.NDPPayload()); got < minSize { + if got := len(icmp.MessageBody()); got < minSize { t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) } @@ -1065,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) if got := ns.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) @@ -1094,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) @@ -1112,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.SolicitedFlag(); got != want { t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) @@ -1183,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) ndpOptions(t, na.Options(), opts) } } @@ -1198,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ndpOptions(t, ns.Options(), opts) } } @@ -1223,7 +1345,261 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.NDPPayload()) + rs := header.NDPRouterSolicit(icmp.MessageBody()) ndpOptions(t, rs.Options(), opts) } } + +// IGMP checks the validity and properties of the given IGMP packet. It is +// expected to be used in conjunction with other IGMP transport checkers for +// specific properties. +func IGMP(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) + } + + igmp := header.IGMP(last.Payload()) + for _, f := range checkers { + f(t, igmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// IGMPType creates a checker that checks the IGMP Type field. +func IGMPType(want header.IGMPType) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.Type(); got != want { + t.Errorf("got igmp.Type() = %d, want = %d", got, want) + } + } +} + +// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. +func IGMPMaxRespTime(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.MaxRespTime(); got != want { + t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) + } + } +} + +// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. +func IGMPGroupAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.GroupAddress(); got != want { + t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) + } + } +} + +// IPv6ExtHdrChecker is a function to check an extension header. +type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) + +// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. +func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv6 := header.IPv6(b) + if !ipv6.IsValid(len(b)) { + t.Error("not a valid IPv6 packet") + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var rawPayloadHeader header.IPv6RawPayloadHeader + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) + return + } + r, ok := h.(header.IPv6RawPayloadHeader) + if ok { + rawPayloadHeader = r + break + } + } + + networkHeader := ipv6HeaderWithExtHdr{ + IPv6: ipv6, + transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), + payload: rawPayloadHeader.Buf.ToView(), + } + + for _, checker := range checkers { + checker(t, []header.Network{&networkHeader}) + } +} + +// IPv6ExtHdr checks for the presence of extension headers. +// +// All the extension headers in headers will be checked exhaustively in the +// order provided. +func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) + if !ok { + t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), + buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), + ) + + for _, check := range headers { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) + return + } + check(t, h) + } + // Validate we consumed all headers. + // + // The next one over should be a raw payload and then iterator should + // terminate. + wantDone := false + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done != wantDone { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) + return + } + if done { + break + } + if _, ok := h.(header.IPv6RawPayloadHeader); !ok { + t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) + continue + } + wantDone = true + } + } +} + +var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) + +// ipv6HeaderWithExtHdr provides a header.Network implementation that takes +// extension headers into consideration, which is not the case with vanilla +// header.IPv6. +type ipv6HeaderWithExtHdr struct { + header.IPv6 + transport tcpip.TransportProtocolNumber + payload []byte +} + +// TransportProtocol implements header.Network. +func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { + return h.transport +} + +// Payload implements header.Network. +func (h *ipv6HeaderWithExtHdr) Payload() []byte { + return h.payload +} + +// IPv6ExtHdrOptionChecker is a function to check an extension header option. +type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) + +// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop +// extension header and validates the containing options with checkers. +// +// checkers must exhaustively contain all the expected options. +func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { + return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { + t.Helper() + + hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) + if !ok { + t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) + return + } + optionsIterator := hbh.Iter() + for _, f := range checkers { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + f(t, opt) + } + // Validate all options were consumed. + for { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if !done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + if done { + break + } + } + } +} + +// IPv6RouterAlert validates that an extension header option is the RouterAlert +// option and matches on its value. +func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + routerAlert, ok := opt.(*header.IPv6RouterAlertOption) + if !ok { + t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) + return + } + if routerAlert.Value != want { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) + } + } +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index d87797617..0bdc12d53 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -11,11 +11,13 @@ go_library( "gue.go", "icmpv4.go", "icmpv6.go", + "igmp.go", "interfaces.go", "ipv4.go", "ipv6.go", "ipv6_extension_headers.go", "ipv6_fragment.go", + "mld.go", "ndp_neighbor_advert.go", "ndp_neighbor_solicit.go", "ndp_options.go", @@ -39,6 +41,8 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "igmp_test.go", + "ipv4_test.go", "ipv6_test.go", "ipversion_test.go", "tcp_test.go", @@ -58,6 +62,7 @@ go_test( srcs = [ "eth_test.go", "ipv6_extension_headers_test.go", + "mld_test.go", "ndp_test.go", ], library = ":header", diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 4303fc5d5..2eef64b4d 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -115,6 +115,12 @@ const ( ICMPv6NeighborSolicit ICMPv6Type = 135 ICMPv6NeighborAdvert ICMPv6Type = 136 ICMPv6RedirectMsg ICMPv6Type = 137 + + // Multicast Listener Discovery (MLD) messages, see RFC 2710. + + ICMPv6MulticastListenerQuery ICMPv6Type = 130 + ICMPv6MulticastListenerReport ICMPv6Type = 131 + ICMPv6MulticastListenerDone ICMPv6Type = 132 ) // IsErrorType returns true if the receiver is an ICMP error type. @@ -245,10 +251,9 @@ func (b ICMPv6) SetSequence(sequence uint16) { binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) } -// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6 -// packet's message body as defined by RFC 4443 section 2.1; the portion of the -// ICMPv6 buffer after the first ICMPv6HeaderSize bytes. -func (b ICMPv6) NDPPayload() []byte { +// MessageBody returns the message body as defined by RFC 4443 section 2.1; the +// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes. +func (b ICMPv6) MessageBody() []byte { return b[ICMPv6HeaderSize:] } diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go new file mode 100644 index 000000000..5c5be1b9d --- /dev/null +++ b/pkg/tcpip/header/igmp.go @@ -0,0 +1,181 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// IGMP represents an IGMP header stored in a byte array. +type IGMP []byte + +// IGMP implements `Transport`. +var _ Transport = (*IGMP)(nil) + +const ( + // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes, + // as per RFC 2236, Section 2, Page 2. + IGMPMinimumSize = 8 + + // IGMPQueryMinimumSize is the minimum size of a valid Membership Query + // Message in bytes, as per RFC 2236, Section 2, Page 2. + IGMPQueryMinimumSize = 8 + + // IGMPReportMinimumSize is the minimum size of a valid Report Message in + // bytes, as per RFC 2236, Section 2, Page 2. + IGMPReportMinimumSize = 8 + + // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message + // in bytes, as per RFC 2236, Section 2, Page 2. + IGMPLeaveMessageMinimumSize = 8 + + // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page + // 3. + IGMPTTL = 1 + + // igmpTypeOffset defines the offset of the type field in an IGMP message. + igmpTypeOffset = 0 + + // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an + // IGMP message. + igmpMaxRespTimeOffset = 1 + + // igmpChecksumOffset defines the offset of the checksum field in an IGMP + // message. + igmpChecksumOffset = 2 + + // igmpGroupAddressOffset defines the offset of the Group Address field in an + // IGMP message. + igmpGroupAddressOffset = 4 + + // IGMPProtocolNumber is IGMP's transport protocol number. + IGMPProtocolNumber tcpip.TransportProtocolNumber = 2 +) + +// IGMPType is the IGMP type field as per RFC 2236. +type IGMPType byte + +// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2. +// Descriptions below come from there. +const ( + // IGMPMembershipQuery indicates that the message type is Membership Query. + // "There are two sub-types of Membership Query messages: + // - General Query, used to learn which groups have members on an + // attached network. + // - Group-Specific Query, used to learn if a particular group + // has any members on an attached network. + // These two messages are differentiated by the Group Address, as + // described in section 1.4 ." + IGMPMembershipQuery IGMPType = 0x11 + // IGMPv1MembershipReport indicates that the message is a Membership Report + // generated by a host using the IGMPv1 protocol: "an additional type of + // message, for backwards-compatibility with IGMPv1" + IGMPv1MembershipReport IGMPType = 0x12 + // IGMPv2MembershipReport indicates that the Message type is a Membership + // Report generated by a host using the IGMPv2 protocol. + IGMPv2MembershipReport IGMPType = 0x16 + // IGMPLeaveGroup indicates that the message type is a Leave Group + // notification message. + IGMPLeaveGroup IGMPType = 0x17 +) + +// Type is the IGMP type field. +func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) } + +// SetType sets the IGMP type field. +func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) } + +// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership +// Query messages, in other cases it is set to 0 by the sender and ignored by +// the receiver. +func (b IGMP) MaxRespTime() time.Duration { + // As per RFC 2236 section 2.2, + // + // The Max Response Time field is meaningful only in Membership Query + // messages, and specifies the maximum allowed time before sending a + // responding report in units of 1/10 second. In all other messages, it + // is set to zero by the sender and ignored by receivers. + return DecisecondToDuration(b[igmpMaxRespTimeOffset]) +} + +// SetMaxRespTime sets the MaxRespTimeField. +func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m } + +// Checksum is the IGMP checksum field. +func (b IGMP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[igmpChecksumOffset:]) +} + +// SetChecksum sets the IGMP checksum field. +func (b IGMP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum) +} + +// GroupAddress gets the Group Address field. +func (b IGMP) GroupAddress() tcpip.Address { + return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize]) +} + +// SetGroupAddress sets the Group Address field. +func (b IGMP) SetGroupAddress(address tcpip.Address) { + if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize)) + } +} + +// SourcePort implements Transport.SourcePort. +func (IGMP) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (IGMP) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (IGMP) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (IGMP) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (IGMP) Payload() []byte { + return nil +} + +// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP +// header. +func IGMPCalculateChecksum(h IGMP) uint16 { + // The header contains a checksum itself, set it aside to avoid checksumming + // the checksum and replace it afterwards. + existingXsum := h.Checksum() + h.SetChecksum(0) + xsum := ^Checksum(h, 0) + h.SetChecksum(existingXsum) + return xsum +} + +// DecisecondToDuration converts a value representing deci-seconds to a +// time.Duration. +func DecisecondToDuration(ds uint8) time.Duration { + return time.Duration(ds) * time.Second / 10 +} diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go new file mode 100644 index 000000000..b6126d29a --- /dev/null +++ b/pkg/tcpip/header/igmp_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestIGMPHeader tests the functions within header.igmp +func TestIGMPHeader(t *testing.T) { + const maxRespTimeTenthSec = 0xF0 + b := []byte{ + 0x11, // IGMP Type, Membership Query + maxRespTimeTenthSec, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) + } + + if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) + } + + igmpType := header.IGMPv2MembershipReport + igmpHeader.SetType(igmpType) + if got := igmpHeader.Type(); got != igmpType { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType) + } + if got := header.IGMPType(b[0]); got != igmpType { + t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType) + } + + respTime := byte(0x02) + igmpHeader.SetMaxRespTime(respTime) + if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) + } + + checksum := uint16(0x0102) + igmpHeader.SetChecksum(checksum) + if got := igmpHeader.Checksum(); got != checksum { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) + } + + groupAddress := tcpip.Address("\x04\x03\x02\x01") + igmpHeader.SetGroupAddress(groupAddress) + if got := igmpHeader.GroupAddress(); got != groupAddress { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) + } +} + +// TestIGMPChecksum ensures that the checksum calculator produces the expected +// checksum. +func TestIGMPChecksum(t *testing.T) { + b := []byte{ + 0x11, // IGMP Type, Membership Query + 0xF0, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + // Calculate the initial checksum after setting the checksum temporarily to 0 + // to avoid checksumming the checksum. + initialChecksum := igmpHeader.Checksum() + igmpHeader.SetChecksum(0) + checksum := ^header.Checksum(b, 0) + igmpHeader.SetChecksum(initialChecksum) + + if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum { + t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum) + } +} + +func TestDecisecondToDuration(t *testing.T) { + const valueInDeciseconds = 5 + if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want { + t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want) + } +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 91fe7b6a5..e6103f4bc 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -100,7 +100,7 @@ type IPv4Fields struct { // // That leaves ten 32 bit (4 byte) fields for options. An attempt to encode // more will fail. - Options IPv4Options + Options IPv4OptionsSerializer } // IPv4 is an IPv4 header. @@ -157,6 +157,9 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + // IPv4AllRoutersGroup is a multicast address for all routers. + IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02" + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP // packet that every IPv4 capable host must be able to // process/reassemble. @@ -282,18 +285,17 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } -// IPv4Options is a buffer that holds all the raw IP options. -type IPv4Options []byte - -// SizeWithPadding implements stack.NetOptions. -// It reports the size to allocate for the Options. RFC 791 page 23 (end of -// section 3.1) says of the padding at the end of the options: +// padIPv4OptionsLength returns the total length for IPv4 options of length l +// after applying padding according to RFC 791: // The internet header padding is used to ensure that the internet // header ends on a 32 bit boundary. -func (o IPv4Options) SizeWithPadding() int { - return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1) +func padIPv4OptionsLength(length uint8) uint8 { + return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1) } +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + // Options returns a buffer holding the options. func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() @@ -372,26 +374,16 @@ func (b IPv4) CalculateChecksum() uint16 { func (b IPv4) Encode(i *IPv4Fields) { // The size of the options defines the size of the whole header and thus the // IHL field. Options are rare and this is a heavily used function so it is - // worth a bit of optimisation here to keep the copy out of the fast path. - hdrLen := IPv4MinimumSize + // worth a bit of optimisation here to keep the serializer out of the fast + // path. + hdrLen := uint8(IPv4MinimumSize) if len(i.Options) != 0 { - // SizeWithPadding is always >= len(i.Options). - aLen := i.Options.SizeWithPadding() - hdrLen += aLen - if hdrLen > len(b) { - panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen)) - } - opts := b[options:] - // This avoids bounds checks on the next line(s) which would happen even - // if there's no work to do. - if n := copy(opts, i.Options); n != aLen { - padding := opts[n:][:aLen-n] - for i := range padding { - padding[i] = 0 - } - } + hdrLen += i.Options.Serialize(b[options:]) + } + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize)) } - b.SetHeaderLength(uint8(hdrLen)) + b.SetHeaderLength(hdrLen) b[tos] = i.TOS b.SetTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) @@ -471,6 +463,10 @@ const ( // options and may appear multiple times. IPv4OptionNOPType IPv4OptionType = 1 + // IPv4OptionRouterAlertType is the option type for the Router Alert option, + // defined in RFC 2113 Section 2.1. + IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80 + // IPv4OptionRecordRouteType is used by each router on the path of the packet // to record its path. It is carried over to an Echo Reply. IPv4OptionRecordRouteType IPv4OptionType = 7 @@ -871,3 +867,162 @@ func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } // Contents implements IPv4Option. func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } + +// Router Alert option specific related constants. +// +// from RFC 2113 section 2.1: +// +// +--------+--------+--------+--------+ +// |10010100|00000100| 2 octet value | +// +--------+--------+--------+--------+ +// +// Type: +// Copied flag: 1 (all fragments must carry the option) +// Option class: 0 (control) +// Option number: 20 (decimal) +// +// Length: 4 +// +// Value: A two octet code with the following values: +// 0 - Router shall examine packet +// 1-65535 - Reserved +const ( + // IPv4OptionRouterAlertLength is the length of a Router Alert option. + IPv4OptionRouterAlertLength = 4 + + // IPv4OptionRouterAlertValue is the only permissible value of the 16 bit + // payload of the router alert option. + IPv4OptionRouterAlertValue = 0 + + // iPv4OptionRouterAlertValueOffset is the offset for the value of a + // RouterAlert option. + iPv4OptionRouterAlertValueOffset = 2 +) + +// IPv4SerializableOption is an interface to represent serializable IPv4 option +// types. +type IPv4SerializableOption interface { + // optionType returns the type identifier of the option. + optionType() IPv4OptionType +} + +// IPv4SerializableOptionPayload is an interface providing serialization of the +// payload of an IPv4 option. +type IPv4SerializableOptionPayload interface { + // length returns the size of the payload. + length() uint8 + + // serializeInto serializes the payload into the provided byte buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // Length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MUST panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto(buffer []byte) uint8 +} + +// IPv4OptionsSerializer is a serializer for IPv4 options. +type IPv4OptionsSerializer []IPv4SerializableOption + +// Length returns the total number of bytes required to serialize the options. +func (s IPv4OptionsSerializer) Length() uint8 { + var total uint8 + for _, opt := range s { + total++ + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Add 1 to reported length to account for the length byte. + total += 1 + withPayload.length() + } + } + return padIPv4OptionsLength(total) +} + +// Serialize serializes the provided list of IPV4 options into b. +// +// Note, b must be of sufficient size to hold all the options in s. See +// IPv4OptionsSerializer.Length for details on the getting the total size +// of a serialized IPv4OptionsSerializer. +// +// Serialize panics if b is not of sufficient size to hold all the options in s. +func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 { + var total uint8 + for _, opt := range s { + ty := opt.optionType() + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Serialize first to reduce bounds checks. + l := 2 + withPayload.serializeInto(b[2:]) + b[0] = byte(ty) + b[1] = l + b = b[l:] + total += l + continue + } + // Options without payload consist only of the type field. + // + // NB: Repeating code from the branch above is intentional to minimize + // bounds checks. + b[0] = byte(ty) + b = b[1:] + total++ + } + + // According to RFC 791: + // + // The internet header padding is used to ensure that the internet + // header ends on a 32 bit boundary. The padding is zero. + padded := padIPv4OptionsLength(total) + b = b[:padded-total] + for i := range b { + b[i] = 0 + } + return padded +} + +var _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil) +var _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil) + +// IPv4SerializableRouterAlertOption provides serialization of the Router Alert +// IPv4 option according to RFC 2113. +type IPv4SerializableRouterAlertOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType { + return IPv4OptionRouterAlertType +} + +// Length implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) length() uint8 { + return IPv4OptionRouterAlertLength - iPv4OptionRouterAlertValueOffset +} + +// SerializeInto implements IPv4SerializableOption. +func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 { + binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue) + return o.length() +} + +var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil) + +// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option. +type IPv4SerializableNOPOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableNOPOption) optionType() IPv4OptionType { + return IPv4OptionNOPType +} + +var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil) + +// IPv4SerializableListEndOption provides serialization for the IPv4 List End +// option. +type IPv4SerializableListEndOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableListEndOption) optionType() IPv4OptionType { + return IPv4OptionListEndType +} diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go new file mode 100644 index 000000000..6475cd694 --- /dev/null +++ b/pkg/tcpip/header/ipv4_test.go @@ -0,0 +1,179 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestIPv4OptionsSerializer(t *testing.T) { + optCases := []struct { + name string + option []header.IPv4SerializableOption + expect []byte + }{ + { + name: "NOP", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableNOPOption{}, + }, + expect: []byte{1, 0, 0, 0}, + }, + { + name: "ListEnd", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableListEndOption{}, + }, + expect: []byte{0, 0, 0, 0}, + }, + { + name: "RouterAlert", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableRouterAlertOption{}, + }, + expect: []byte{148, 4, 0, 0}, + }, { + name: "NOP and RouterAlert", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableNOPOption{}, + &header.IPv4SerializableRouterAlertOption{}, + }, + expect: []byte{1, 148, 4, 0, 0, 0, 0, 0}, + }, + } + + for _, opt := range optCases { + t.Run(opt.name, func(t *testing.T) { + s := header.IPv4OptionsSerializer(opt.option) + l := s.Length() + if got := len(opt.expect); got != int(l) { + t.Fatalf("s.Length() = %d, want = %d", got, l) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with full bytes to ensure padding is being set + // correctly. + b[i] = 0xFF + } + if serializedLength := s.Serialize(b); serializedLength != l { + t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l) + } + if diff := cmp.Diff(opt.expect, b); diff != "" { + t.Errorf("mismatched serialized option (-want +got):\n%s", diff) + } + }) + } +} + +// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested +// fields when options are supplied. +func TestIPv4EncodeOptions(t *testing.T) { + tests := []struct { + name string + numberOfNops int + encodedOptions header.IPv4Options // reply should look like this + wantIHL int + }{ + { + name: "valid no options", + wantIHL: header.IPv4MinimumSize, + }, + { + name: "one byte options", + numberOfNops: 1, + encodedOptions: header.IPv4Options{1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "two byte options", + numberOfNops: 2, + encodedOptions: header.IPv4Options{1, 1, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "three byte options", + numberOfNops: 3, + encodedOptions: header.IPv4Options{1, 1, 1, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "four byte options", + numberOfNops: 4, + encodedOptions: header.IPv4Options{1, 1, 1, 1}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "five byte options", + numberOfNops: 5, + encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 8, + }, + { + name: "thirty nine byte options", + numberOfNops: 39, + encodedOptions: header.IPv4Options{ + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 0, + }, + wantIHL: header.IPv4MinimumSize + 40, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops)) + for i := range serializeOpts { + serializeOpts[i] = &header.IPv4SerializableNOPOption{} + } + paddedOptionLength := serializeOpts.Length() + ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength) + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength) + hdr := buffer.NewPrependable(int(totalLen)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + // To check the padding works, poison the last byte of the options space. + if paddedOptionLength != serializeOpts.Length() { + ip.SetHeaderLength(uint8(ipHeaderLength)) + ip.Options()[paddedOptionLength-1] = 0xff + ip.SetHeaderLength(0) + } + ip.Encode(&header.IPv4Fields{ + Options: serializeOpts, + }) + options := ip.Options() + wantOptions := test.encodedOptions + if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { + t.Errorf("got IHL of %d, want %d", got, want) + } + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(wantOptions) == 0 && len(options) == 0 { + return + } + + if diff := cmp.Diff(wantOptions, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 55d09355a..d522e5f10 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -48,11 +48,13 @@ type IPv6Fields struct { // FlowLabel is the "flow label" field of an IPv6 packet. FlowLabel uint32 - // PayloadLength is the "payload length" field of an IPv6 packet. + // PayloadLength is the "payload length" field of an IPv6 packet, including + // the length of all extension headers. PayloadLength uint16 - // NextHeader is the "next header" field of an IPv6 packet. - NextHeader uint8 + // TransportProtocol is the transport layer protocol number. Serialized in the + // last "next header" field of the IPv6 header + extension headers. + TransportProtocol tcpip.TransportProtocolNumber // HopLimit is the "Hop Limit" field of an IPv6 packet. HopLimit uint8 @@ -62,6 +64,9 @@ type IPv6Fields struct { // DstAddr is the "destination ip address" of an IPv6 packet. DstAddr tcpip.Address + + // ExtensionHeaders are the extension headers following the IPv6 header. + ExtensionHeaders IPv6ExtHdrSerializer } // IPv6 represents an ipv6 header stored in a byte array. @@ -253,12 +258,14 @@ func (IPv6) SetChecksum(uint16) { // Encode encodes all the fields of the ipv6 header. func (b IPv6) Encode(i *IPv6Fields) { + extHdr := b[IPv6MinimumSize:] b.SetTOS(i.TrafficClass, i.FlowLabel) b.SetPayloadLength(i.PayloadLength) - b[IPv6NextHeaderOffset] = i.NextHeader b[hopLimit] = i.HopLimit b.SetSourceAddress(i.SrcAddr) b.SetDestinationAddress(i.DstAddr) + nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr) + b[IPv6NextHeaderOffset] = nextHeader } // IsValid performs basic validation on the packet. diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 583c2c5d3..f18981332 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -18,9 +18,12 @@ import ( "bufio" "bytes" "encoding/binary" + "errors" "fmt" "io" + "math" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -47,6 +50,11 @@ const ( // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end // of an IPv6 payload, as per RFC 8200 section 4.7. IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59 + + // IPv6UnknownExtHdrIdentifier is reserved by IANA. + // https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header + // "254 Use for experimentation and testing [RFC3692][RFC4727]" + IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254 ) const ( @@ -70,8 +78,8 @@ const ( // Fragment Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetOffset = 0 - // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to - // discard from the Fragment Offset. + // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment + // Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetShift = 3 // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an @@ -109,6 +117,37 @@ const ( IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 ) +// padIPv6OptionsLength returns the total length for IPv6 options of length l +// considering the 8-octet alignment as stated in RFC 8200 Section 4.2. +func padIPv6OptionsLength(length int) int { + return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1) +} + +// padIPv6Option fills b with the appropriate padding options depending on its +// length. +func padIPv6Option(b []byte) { + switch len(b) { + case 0: // No padding needed. + case 1: // Pad with Pad1. + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier) + default: // Pad with PadN. + s := b[ipv6ExtHdrOptionPayloadOffset:] + for i := range s { + s[i] = 0 + } + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier) + b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s)) + } +} + +// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to +// serialize an option at headerOffset with alignment requirements +// [align]n + alignOffset. +func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int { + padLen := headerOffset - alignOffset + return ((padLen + align - 1) & ^(align - 1)) - padLen +} + // IPv6PayloadHeader is implemented by the various headers that can be found // in an IPv6 payload. // @@ -201,29 +240,55 @@ type IPv6ExtHdrOption interface { isIPv6ExtHdrOption() } -// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier. -type IPv6ExtHdrOptionIndentifier uint8 +// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier. +type IPv6ExtHdrOptionIdentifier uint8 const ( // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that // provides 1 byte padding, as outlined in RFC 8200 section 4.2. - ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0 + ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0 // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that // provides variable length byte padding, as outlined in RFC 8200 section 4.2. - ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1 + ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1 + + // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router + // Alert Hop by Hop option as defined in RFC 2711 section 2.1. + ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5 + + // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header + // option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionTypeOffset = 0 + + // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionLengthOffset = 1 + + // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionPayloadOffset = 2 ) +// ipv6UnknownActionFromIdentifier maps an extension header option's +// identifier's high bits to the action to take when the identifier is unknown. +func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction { + return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) +} + +// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option +// is malformed. +var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option") + // IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension // header option that is unknown by the parsing utilities. type IPv6UnknownExtHdrOption struct { - Identifier IPv6ExtHdrOptionIndentifier + Identifier IPv6ExtHdrOptionIdentifier Data []byte } // UnknownAction implements IPv6OptionUnknownAction.UnknownAction. func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { - return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) + return ipv6UnknownActionFromIdentifier(o.Identifier) } // isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. @@ -246,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // options buffer has been exhausted and we are done iterating. return nil, true, nil } - id := IPv6ExtHdrOptionIndentifier(temp) + id := IPv6ExtHdrOptionIdentifier(temp) // If the option identifier indicates the option is a Pad1 option, then we // know the option does not have Length and Data fields. End processing of @@ -289,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } continue + case ipv6RouterAlertHopByHopOptionIdentifier: + var routerAlertValue [ipv6RouterAlertPayloadLength]byte + if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil { + switch err { + case io.EOF, io.ErrUnexpectedEOF: + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + default: + return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err) + } + } else if n != int(length) { + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + } + return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil default: bytes := make([]byte, length) if n, err := io.ReadFull(&i.reader, bytes); err != nil { @@ -452,9 +530,11 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Since we consume the iterator, we return the payload as is. buf = i.payload - // Mark i as done. + // Mark i as done, but keep track of where we were for error reporting. *i = IPv6PayloadIterator{ nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + headerOffset: i.headerOffset, + nextOffset: i.nextOffset, } } else { buf = i.payload.Clone(nil) @@ -602,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil } + +// IPv6SerializableExtHdr provides serialization for IPv6 extension +// headers. +type IPv6SerializableExtHdr interface { + // identifier returns the assigned IPv6 header identifier for this extension + // header. + identifier() IPv6ExtensionHeaderIdentifier + + // length returns the total serialized length in bytes of this extension + // header, including the common next header and length fields. + length() int + + // serializeInto serializes the receiver into the provided byte + // buffer and with the provided nextHeader value. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto returns the number of bytes that was used to serialize the + // receiver. Implementers must only use the number of bytes required to + // serialize the receiver. Callers MAY provide a larger buffer than required + // to serialize into. + serializeInto(nextHeader uint8, b []byte) int +} + +var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil) + +// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop +// options extension header. +type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption + +const ( + // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field + // in a hop by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrNextHeaderOffset = 0 + + // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop + // by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrLengthOffset = 1 + + // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by + // hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrOptionsOffset = 2 + + // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet + // words in a hop by hop extension header's length field, as stated in RFC + // 8200 section 4.3: + // Length of the Hop-by-Hop Options header in 8-octet units, + // not including the first 8 octets. + ipv6HopByHopExtHdrUnaccountedLenWords = 1 +) + +// identifier implements IPv6SerializableExtHdr. +func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6HopByHopOptionsExtHdrIdentifier +} + +// length implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) length() int { + var total int + for _, opt := range h { + align, alignOffset := opt.alignment() + total += ipv6OptionsAlignmentPadding(total, align, alignOffset) + total += ipv6ExtHdrOptionPayloadOffset + int(opt.length()) + } + // Account for next header and total length fields and add padding. + return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total) +} + +// serializeInto implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int { + optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:] + totalLength := ipv6HopByHopExtHdrOptionsOffset + for _, opt := range h { + // Calculate alignment requirements and pad buffer if necessary. + align, alignOffset := opt.alignment() + padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset) + if padLen != 0 { + padIPv6Option(optBuffer[:padLen]) + totalLength += padLen + optBuffer = optBuffer[padLen:] + } + + l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:]) + optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier()) + optBuffer[ipv6ExtHdrOptionLengthOffset] = l + l += ipv6ExtHdrOptionPayloadOffset + totalLength += int(l) + optBuffer = optBuffer[l:] + } + padded := padIPv6OptionsLength(totalLength) + if padded != totalLength { + padIPv6Option(optBuffer[:padded-totalLength]) + totalLength = padded + } + wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords + if wordsLen > math.MaxUint8 { + panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen)) + } + b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader + b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen) + return totalLength +} + +// IPv6SerializableHopByHopOption provides serialization for hop by hop options. +type IPv6SerializableHopByHopOption interface { + // identifier returns the option identifier of this Hop by Hop option. + identifier() IPv6ExtHdrOptionIdentifier + + // length returns the *payload* size of the option (not considering the type + // and length fields). + length() uint8 + + // alignment returns the alignment requirements from this option. + // + // Alignment requirements take the form [align]n + offset as specified in + // RFC 8200 section 4.2. The alignment requirement is on the offset between + // the option type byte and the start of the hop by hop header. + // + // align must be a power of 2. + alignment() (align int, offset int) + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) uint8 +} + +var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil) + +// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in +// RFC 2711 section 2.1. +type IPv6RouterAlertOption struct { + Value IPv6RouterAlertValue +} + +// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option. +type IPv6RouterAlertValue uint16 + +const ( + // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener + // Discovery message as defined in RFC 2711 section 2.1. + IPv6RouterAlertMLD IPv6RouterAlertValue = 0 + // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as + // defined in RFC 2711 section 2.1. + IPv6RouterAlertRSVP IPv6RouterAlertValue = 1 + // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active + // Networks message as defined in RFC 2711 section 2.1. + IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2 + + // ipv6RouterAlertPayloadLength is the length of the Router Alert payload + // as defined in RFC 2711. + ipv6RouterAlertPayloadLength = 2 + + // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the + // Router Alert option defined as 2n+0 in RFC 2711. + ipv6RouterAlertAlignmentRequirement = 2 + + // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset + // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section + // 2.1. + ipv6RouterAlertAlignmentOffsetRequirement = 0 +) + +// UnknownAction implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction { + return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {} + +// identifier implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier { + return ipv6RouterAlertHopByHopOptionIdentifier +} + +// length implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) length() uint8 { + return ipv6RouterAlertPayloadLength +} + +// alignment implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) alignment() (int, int) { + // From RFC 2711 section 2.1: + // Alignment requirement: 2n+0. + return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 { + binary.BigEndian.PutUint16(b, uint16(o.Value)) + return ipv6RouterAlertPayloadLength +} + +// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers. +type IPv6ExtHdrSerializer []IPv6SerializableExtHdr + +// Serialize serializes the provided list of IPv6 extension headers into b. +// +// Note, b must be of sufficient size to hold all the headers in s. See +// IPv6ExtHdrSerializer.Length for details on the getting the total size of a +// serialized IPv6ExtHdrSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +// +// Serialize takes the transportProtocol value to be used as the last extension +// header's Next Header value and returns the header identifier of the first +// serialized extension header and the total serialized length. +func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) { + nextHeader := uint8(transportProtocol) + if len(s) == 0 { + return nextHeader, 0 + } + var totalLength int + for i, h := range s[:len(s)-1] { + length := h.serializeInto(uint8(s[i+1].identifier()), b) + b = b[length:] + totalLength += length + } + totalLength += s[len(s)-1].serializeInto(nextHeader, b) + return uint8(s[0].identifier()), totalLength +} + +// Length returns the total number of bytes required to serialize the extension +// headers. +func (s IPv6ExtHdrSerializer) Length() int { + var totalLength int + for _, h := range s { + totalLength += h.length() + } + return totalLength +} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go index ab20c5f37..65adc6250 100644 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool func TestIPv6UnknownExtHdrOption(t *testing.T) { tests := []struct { name string - identifier IPv6ExtHdrOptionIndentifier + identifier IPv6ExtHdrOptionIdentifier expectedUnknownAction IPv6OptionUnknownAction }{ { @@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) { bytes: []byte{1, 3}, err: io.ErrUnexpectedEOF, }, + { + name: "Router alert without data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data and Pad1", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with extra data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with missing data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1}, + err: io.ErrUnexpectedEOF, + }, } check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) { @@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) { }) } } + +var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil) + +// dummyHbHOptionSerializer provides a generic implementation of +// IPv6SerializableHopByHopOption for use in tests. +type dummyHbHOptionSerializer struct { + id IPv6ExtHdrOptionIdentifier + payload []byte + align int + alignOffset int +} + +// identifier implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier { + return s.id +} + +// length implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) length() uint8 { + return uint8(len(s.payload)) +} + +// alignment implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) alignment() (int, int) { + align := 1 + if s.align != 0 { + align = s.align + } + return align, s.alignOffset +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 { + return uint8(copy(b, s.payload)) +} + +func TestIPv6HopByHopSerializer(t *testing.T) { + validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + dummy, ok := serializable.(*dummyHbHOptionSerializer) + if !ok { + t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable) + } + unknown, ok := deserialized.(*IPv6UnknownExtHdrOption) + if !ok { + t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{}) + } + if dummy.id != unknown.Identifier { + t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id) + } + if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" { + t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff) + } + } + tests := []struct { + name string + nextHeader uint8 + options []IPv6SerializableHopByHopOption + expect []byte + validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption) + }{ + { + name: "single option", + nextHeader: 13, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 15, + payload: []byte{9, 8, 7, 6}, + }, + }, + expect: []byte{13, 0, 15, 4, 9, 8, 7, 6}, + validate: validateDummies, + }, + { + name: "short option padN zero", + nextHeader: 88, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5}, + }, + }, + expect: []byte{88, 0, 22, 2, 4, 5, 1, 0}, + validate: validateDummies, + }, + { + name: "short option pad1", + nextHeader: 11, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 33, + payload: []byte{1, 2, 3}, + }, + }, + expect: []byte{11, 0, 33, 3, 1, 2, 3, 0}, + validate: validateDummies, + }, + { + name: "long option padN", + nextHeader: 55, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 77, + payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + }, + expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options align 2n", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 2, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0}, + validate: validateDummies, + }, + { + name: "two options align 8n+1", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 8, + alignOffset: 1, + }, + }, + expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0}, + validate: validateDummies, + }, + { + name: "no options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{}, + expect: []byte{33, 0, 1, 4, 0, 0, 0, 0}, + }, + { + name: "Router Alert", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}}, + expect: []byte{33, 0, 5, 2, 0, 0, 1, 0}, + validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + routerAlert, ok := deserialized.(*IPv6RouterAlertOption) + if !ok { + t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized) + } + if routerAlert.Value != IPv6RouterAlertMLD { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD) + } + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6SerializableHopByHopExtHdr(test.options) + length := s.length() + if length != len(test.expect) { + t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect)) + } + b := make([]byte, length) + for i := range b { + // Fill the buffer with ones to ensure all padding is correctly set. + b[i] = 0xFF + } + if got := s.serializeInto(test.nextHeader, b); got != length { + t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length) + } + if diff := cmp.Diff(test.expect, b); diff != "" { + t.Fatalf("serialization mismatch (-want +got):\n%s", diff) + } + + // Deserialize the options and verify them. + optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit + iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter() + for _, testOpt := range test.options { + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + test.validate(t, testOpt, opt) + } + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if !done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + }) + } +} + +var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil) + +// dummyIPv6ExtHdrSerializer provides a generic implementation of +// IPv6SerializableExtHdr for use in tests. +// +// The dummy header always carries the nextHeader value in the first byte. +type dummyIPv6ExtHdrSerializer struct { + id IPv6ExtensionHeaderIdentifier + headerContents []byte +} + +// identifier implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier { + return s.id +} + +// length implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) length() int { + return len(s.headerContents) + 1 +} + +// serializeInto implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int { + b[0] = nextHeader + return copy(b[1:], s.headerContents) + 1 +} + +func TestIPv6ExtHdrSerializer(t *testing.T) { + tests := []struct { + name string + headers []IPv6SerializableExtHdr + nextHeader tcpip.TransportProtocolNumber + expectSerialized []byte + expectNextHeader uint8 + }{ + { + name: "one header", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 15, + headerContents: []byte{1, 2, 3, 4}, + }, + }, + nextHeader: TCPProtocolNumber, + expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4}, + expectNextHeader: 15, + }, + { + name: "two headers", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 22, + headerContents: []byte{1, 2, 3}, + }, + &dummyIPv6ExtHdrSerializer{ + id: 23, + headerContents: []byte{4, 5, 6}, + }, + }, + nextHeader: ICMPv6ProtocolNumber, + expectSerialized: []byte{ + 23, 1, 2, 3, + byte(ICMPv6ProtocolNumber), 4, 5, 6, + }, + expectNextHeader: 22, + }, + { + name: "no headers", + headers: []IPv6SerializableExtHdr{}, + nextHeader: UDPProtocolNumber, + expectSerialized: []byte{}, + expectNextHeader: byte(UDPProtocolNumber), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6ExtHdrSerializer(test.headers) + l := s.Length() + if got, want := l, len(test.expectSerialized); got != want { + t.Fatalf("got serialized length = %d, want = %d", got, want) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with garbage to make sure we're writing to all bytes. + b[i] = 0xFF + } + nextHeader, serializedLen := s.Serialize(test.nextHeader, b) + if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader { + t.Errorf( + "got s.Serialize(..) = (%d, %d), want = (%d, %d)", + nextHeader, + serializedLen, + test.expectNextHeader, + len(test.expectSerialized), + ) + } + if diff := cmp.Diff(test.expectSerialized, b); diff != "" { + t.Errorf("serialization mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go index 018555a26..9d09f32eb 100644 --- a/pkg/tcpip/header/ipv6_fragment.go +++ b/pkg/tcpip/header/ipv6_fragment.go @@ -27,12 +27,11 @@ const ( idV6 = 4 ) -// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the -// fields of a packet that needs to be encoded. -type IPv6FragmentFields struct { - // NextHeader is the "next header" field of an IPv6 fragment. - NextHeader uint8 +var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil) +// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment +// extension header as defined in RFC 8200 section 4.5. +type IPv6SerializableFragmentExtHdr struct { // FragmentOffset is the "fragment offset" field of an IPv6 fragment. FragmentOffset uint16 @@ -43,6 +42,29 @@ type IPv6FragmentFields struct { Identification uint32 } +// identifier implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6FragmentHeader +} + +// length implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) length() int { + return IPv6FragmentHeaderSize +} + +// serializeInto implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int { + // Prevent too many bounds checks. + _ = b[IPv6FragmentHeaderSize:] + binary.BigEndian.PutUint32(b[idV6:], h.Identification) + binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift) + b[nextHdrFrag] = nextHeader + if h.M { + b[more] |= ipv6FragmentExtHdrMFlagMask + } + return IPv6FragmentHeaderSize +} + // IPv6Fragment represents an ipv6 fragment header stored in a byte array. // Most of the methods of IPv6Fragment access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. @@ -58,16 +80,6 @@ const ( IPv6FragmentHeaderSize = 8 ) -// Encode encodes all the fields of the ipv6 fragment. -func (b IPv6Fragment) Encode(i *IPv6FragmentFields) { - b[nextHdrFrag] = i.NextHeader - binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3) - if i.M { - b[more] |= 1 - } - binary.BigEndian.PutUint32(b[idV6:], i.Identification) -} - // IsValid performs basic validation on the fragment header. func (b IPv6Fragment) IsValid() bool { return len(b) >= IPv6FragmentHeaderSize diff --git a/pkg/tcpip/header/mld.go b/pkg/tcpip/header/mld.go new file mode 100644 index 000000000..ffe03c76a --- /dev/null +++ b/pkg/tcpip/header/mld.go @@ -0,0 +1,103 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // MLDMinimumSize is the minimum size for an MLD message. + MLDMinimumSize = 20 + + // MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as + // per RFC 2710 section 3. + MLDHopLimit = 1 + + // mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay + // field within MLD. + mldMaximumResponseDelayOffset = 0 + + // mldMulticastAddressOffset is the offset to the Multicast Address field + // within MLD. + mldMulticastAddressOffset = 4 +) + +// MLD is a Multicast Listener Discovery message in an ICMPv6 packet. +// +// MLD will only contain the body of an ICMPv6 packet. +// +// As per RFC 2710 section 3, MLD messages have the following format (MLD only +// holds the bytes after the first four bytes in the diagram below): +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Maximum Response Delay | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// + + +// | | +// + Multicast Address + +// | | +// + + +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +type MLD []byte + +// MaximumResponseDelay returns the Maximum Response Delay. +func (m MLD) MaximumResponseDelay() time.Duration { + // As per RFC 2710 section 3.4: + // + // The Maximum Response Delay field is meaningful only in Query + // messages, and specifies the maximum allowed delay before sending a + // responding Report, in units of milliseconds. In all other messages, + // it is set to zero by the sender and ignored by receivers. + return time.Duration(binary.BigEndian.Uint16(m[mldMaximumResponseDelayOffset:])) * time.Millisecond +} + +// SetMaximumResponseDelay sets the Maximum Response Delay field. +// +// maxRespDelayMS is the value in milliseconds. +func (m MLD) SetMaximumResponseDelay(maxRespDelayMS uint16) { + binary.BigEndian.PutUint16(m[mldMaximumResponseDelayOffset:], maxRespDelayMS) +} + +// MulticastAddress returns the Multicast Address. +func (m MLD) MulticastAddress() tcpip.Address { + // As per RFC 2710 section 3.5: + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + // + // In a Report or Done message, the Multicast Address field holds a + // specific IPv6 multicast address to which the message sender is + // listening or is ceasing to listen, respectively. + return tcpip.Address(m[mldMulticastAddressOffset:][:IPv6AddressSize]) +} + +// SetMulticastAddress sets the Multicast Address field. +func (m MLD) SetMulticastAddress(multicastAddress tcpip.Address) { + if n := copy(m[mldMulticastAddressOffset:], multicastAddress); n != IPv6AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, IPv6AddressSize)) + } +} diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go new file mode 100644 index 000000000..0cecf10d4 --- /dev/null +++ b/pkg/tcpip/header/mld_test.go @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestMLD(t *testing.T) { + b := []byte{ + // Maximum Response Delay + 0, 0, + + // Reserved + 0, 0, + + // MulticastAddress + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, + } + + const maxRespDelay = 513 + binary.BigEndian.PutUint16(b, maxRespDelay) + + mld := MLD(b) + + if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want { + t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) + } + + const newMaxRespDelay = 1234 + mld.SetMaximumResponseDelay(newMaxRespDelay) + if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want { + t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) + } + + if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want { + t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want) + } + + multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}) + mld.SetMulticastAddress(multicastAddress) + if got := mld.MulticastAddress(); got != multicastAddress { + t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress) + } +} diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 5d3975c56..554242f0c 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -298,7 +298,7 @@ func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { return it, nil } -// Serialize serializes the provided list of NDP options into o. +// Serialize serializes the provided list of NDP options into b. // // Note, b must be of sufficient size to hold all the options in s. See // NDPOptionsSerializer.Length for details on the getting the total size diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 98bdd29db..a6d4fcd59 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -36,10 +36,10 @@ const ( // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { - // SrcPort is the "source port" field of a UDP packet. + // SrcPort is the "Source Port" field of a UDP packet. SrcPort uint16 - // DstPort is the "destination port" field of a UDP packet. + // DstPort is the "Destination Port" field of a UDP packet. DstPort uint16 // Length is the "length" field of a UDP packet. @@ -64,52 +64,57 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. +// SourcePort returns the "Source Port" field of the UDP header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "destination port" field of the udp header. +// DestinationPort returns the "Destination Port" field of the UDP header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "length" field of the udp header. +// Length returns the "Length" field of the UDP header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } // Payload returns the data contained in the UDP datagram. func (b UDP) Payload() []byte { - return b[UDPMinimumSize:] + return b[:b.Length()][UDPMinimumSize:] } -// Checksum returns the "checksum" field of the udp header. +// Checksum returns the "checksum" field of the UDP header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the udp header. +// SetSourcePort sets the "source port" field of the UDP header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the udp header. +// SetDestinationPort sets the "destination port" field of the UDP header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the udp header. +// SetChecksum sets the "checksum" field of the UDP header. func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the udp header. +// SetLength sets the "length" field of the UDP header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// CalculateChecksum calculates the checksum of the udp packet, given the +// PayloadLength returns the length of the payload following the UDP header. +func (b UDP) PayloadLength() uint16 { + return b.Length() - UDPMinimumSize +} + +// CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 39ca774ef..973f06cbc 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -9,7 +9,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c95aef63c..0efbfb22b 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -32,7 +31,7 @@ type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO - Route stack.Route + Route *stack.Route } // Notification is the interface for receiving notification from the packet @@ -271,21 +270,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := PacketInfo{ - Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }), - Proto: 0, - GSO: nil, - } - - e.q.Write(p) - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index 3eef7cd56..beefcd008 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -62,7 +62,7 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { // WritePacket implements stack.LinkEndpoint. func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt) return e.Endpoint.WritePacket(r, gso, proto, pkt) } @@ -71,7 +71,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) + e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, proto) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 975309fc8..cb94cbea6 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher } switch sa.(type) { case *unix.SockaddrLinklayer: - // enable PACKET_FANOUT mode is the underlying socket is - // of type AF_PACKET. - const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG + // Enable PACKET_FANOUT mode if the underlying socket is of type + // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will + // prevent gvisor from receiving fragmented packets and the host does the + // reassembly on our behalf before delivering the fragments. This makes it + // hard to test fragmentation reassembly code in Netstack. + const fanoutType = unix.PACKET_FANOUT_HASH fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) @@ -410,7 +413,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt) } var builder iovec.Builder @@ -453,7 +456,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { if e.hdrSize > 0 { - e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress(), pkt.NetworkProtocolNumber, pkt) } var vnetHdrBuf []byte @@ -558,11 +561,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) -} - // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 709f829c8..ce4da7230 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -183,9 +183,8 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) defer c.cleanup() - r := &stack.Route{ - RemoteLinkAddress: raddr, - } + var r stack.Route + r.ResolveWith(raddr) // Build payload. payload := buffer.NewView(plen) @@ -220,7 +219,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { + if err := c.ep.WritePacket(&r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -325,9 +324,9 @@ func TestPreserveSrcAddress(t *testing.T) { // Set LocalLinkAddress in route to the value of the bridged address. r := &stack.Route{ - RemoteLinkAddress: raddr, - LocalLinkAddress: baddr, + LocalLinkAddress: baddr, } + r.ResolveWith(raddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ // WritePacket panics given a prependable with anything less than diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 38aa694e4..edca57e4e 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - // There should be an ethernet header at the beginning of vv. - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. - return tcpip.ErrBadAddress - } - linkHeader := header.Ethernet(hdr) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) - - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareLoopback diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index e7493e5c5..cbda59775 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 56a611825..22e79ce3a 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -17,7 +17,6 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco return tcpip.ErrNoRoute } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // WriteRawPacket doesn't get a route or network address, so there's - // nowhere to write this. - return tcpip.ErrNoRoute -} - // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index 2cdb23475..00b42b924 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index d40de54df..0ee54c3d5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -19,7 +19,6 @@ package nested import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return e.child.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return e.child.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint. func (e *Endpoint) Wait() { e.child.Wait() diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index 3922c2a04..9a1b0c0c2 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -36,14 +36,14 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { // WritePacket implements stack.LinkEndpoint.WritePacket. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress(), r.LocalLinkAddress, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress(), pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, proto) diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 523b0d24b..25c364391 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -55,7 +55,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this // endpoint is the local address on the other end. - e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), })) @@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint. -func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - panic("not implemented") -} - // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 1d0079bd6..5bea598eb 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -13,7 +13,6 @@ go_library( "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index fc1e34fc7..27667f5f0 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -156,7 +155,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. newRoute := r.Clone() - pkt.EgressRoute = &newRoute + pkt.EgressRoute = newRoute pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -183,7 +182,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // the route here to ensure it doesn't get released while the // packet is still in our queue. newRoute := pkt.EgressRoute.Clone() - pkt.EgressRoute = &newRoute + pkt.EgressRoute = newRoute if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() @@ -197,13 +196,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB return enqueued, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // TODO(gvisor.dev/issue/3267): Queue these packets as well once - // WriteRawPacket takes PacketBuffer instead of VectorisedView. - return e.lower.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 7fb8a6c49..5660418fa 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -204,7 +204,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt) views := pkt.Views() // Transmit the packet. @@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - views := vv.Views() - // Transmit the packet. - e.mu.Lock() - ok := e.tx.transmit(views...) - e.mu.Unlock() - - if !ok { - return tcpip.ErrWouldBlock - } - - return nil -} - // dispatchLoop reads packets from the rx queue in a loop and dispatches them // to the network stack. func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 22d5c97f1..7131392cc 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -260,9 +260,8 @@ func TestSimpleSend(t *testing.T) { defer c.cleanup() // Prepare route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) for iters := 1000; iters > 0; iters-- { func() { @@ -342,9 +341,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) // Set both remote and local link address in route. r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - LocalLinkAddress: newLocalLinkAddress, + LocalLinkAddress: newLocalLinkAddress, } + r.ResolveWith(remoteLinkAddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ // WritePacket panics given a prependable with anything less than @@ -395,9 +394,8 @@ func TestFillTxQueue(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -444,9 +442,8 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { c.txq.rx.Flush() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -509,9 +506,8 @@ func TestFillTxMemory(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -557,9 +553,8 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index b3e8c4b92..8d9a91020 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -53,16 +53,35 @@ type endpoint struct { nested.Endpoint writer io.Writer maxPCAPLen uint32 + logPrefix string } var _ stack.GSOEndpoint = (*endpoint)(nil) var _ stack.LinkEndpoint = (*endpoint)(nil) var _ stack.NetworkDispatcher = (*endpoint)(nil) +type direction int + +const ( + directionSend = iota + directionRecv +) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - sniffer := &endpoint{} + return NewWithPrefix(lower, "") +} + +// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around +// another endpoint and logs packets prefixed with logPrefix as they traverse +// the endpoint. +// +// logPrefix is prepended to the log line without any separators. +// E.g. logPrefix = "NIC:en0/" will produce log lines like +// "NIC:en0/send udp [...]". +func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint { + sniffer := &endpoint{logPrefix: logPrefix} sniffer.Endpoint.Init(lower, sniffer) return sniffer } @@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, pkt) + e.dumpPacket(directionRecv, nil, protocol, pkt) e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } -func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + logPacket(e.logPrefix, dir, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Size() @@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // forwards the request to the lower endpoint. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return e.Endpoint.WriteRawPacket(vv) -} - -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -201,6 +212,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool + var directionPrefix string + switch dir { + case directionSend: + directionPrefix = "send" + case directionRecv: + directionPrefix = "recv" + default: + panic(fmt.Sprintf("unrecognized direction: %d", dir)) + } + // Clone the packet buffer to not modify the original. // // We don't clone the original packet buffer so that the new packet buffer @@ -248,15 +269,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( - "%s arp %s (%s) -> %s (%s) valid:%t", + "%s%s arp %s (%s) -> %s (%s) valid:%t", prefix, + directionPrefix, tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), arp.IsValid(), ) return default: - log.Infof("%s unknown network protocol", prefix) + log.Infof("%s%s unknown network protocol", prefix, directionPrefix) return } @@ -300,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P icmpType = "info reply" } } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.ICMPv6ProtocolNumber: @@ -335,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6RedirectMsg: icmpType = "redirect message" } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.UDPProtocolNumber: @@ -391,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P } default: - log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto) + log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto) return } @@ -399,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P details += fmt.Sprintf(" gso: %+v", gso) } - log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) + log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 9a76bdba7..a364c5801 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" { return nil, false } @@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader().View().IsEmpty() { - d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt) } vv.AppendView(info.Pkt.LinkHeader().View()) } diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index ee84c3d96..9b4602c1b 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/gate", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], @@ -25,7 +24,6 @@ go_test( library = ":waitable", deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index b152a0f26..cf0077f43 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -24,7 +24,6 @@ package waitable import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, err } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if !e.writeGate.Enter() { - return nil - } - - err := e.lower.WriteRawPacket(vv) - e.writeGate.Leave() - return err -} - // WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, // and waits for inflight ones to finish before returning. func (e *Endpoint) WaitWrite() { diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 94827fc56..cf7fb5126 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -18,7 +18,6 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack. return pkts.Len(), nil } -func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - e.writeCount++ - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unimplemented") diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index b38aff0b8..9ebf31b78 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -7,12 +7,14 @@ go_test( size = "small", srcs = [ "ip_test.go", + "multicast_group_test.go", ], deps = [ "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index f462524c9..0fb373612 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -319,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { copy(h.HardwareAddressSender(), test.senderLinkAddr) copy(h.ProtocolAddressSender(), test.senderAddr) copy(h.ProtocolAddressTarget(), test.targetAddr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) if !test.isValid { // No packets should be sent after receiving an invalid ARP request. @@ -442,9 +442,9 @@ func (*testInterface) Promiscuous() bool { func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -557,8 +557,8 @@ func TestLinkAddressRequest(t *testing.T) { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index d8e4a3b54..429af69ee 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -18,7 +18,6 @@ go_template_instance( go_library( name = "fragmentation", srcs = [ - "frag_heap.go", "fragmentation.go", "reassembler.go", "reassembler_list.go", @@ -38,7 +37,6 @@ go_test( name = "fragmentation_test", size = "small", srcs = [ - "frag_heap_test.go", "fragmentation_test.go", "reassembler_test.go", ], diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go deleted file mode 100644 index 0b570d25a..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "fmt" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -type fragment struct { - offset uint16 - vv buffer.VectorisedView -} - -type fragHeap []fragment - -func (h *fragHeap) Len() int { - return len(*h) -} - -func (h *fragHeap) Less(i, j int) bool { - return (*h)[i].offset < (*h)[j].offset -} - -func (h *fragHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *fragHeap) Push(x interface{}) { - *h = append(*h, x.(fragment)) -} - -func (h *fragHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[:n-1] - return x -} - -// reassamble empties the heap and returns a VectorisedView -// containing a reassambled version of the fragments inside the heap. -func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { - curr := heap.Pop(h).(fragment) - views := curr.vv.Views() - size := curr.vv.Size() - - if curr.offset != 0 { - return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) - } - - for h.Len() > 0 { - curr := heap.Pop(h).(fragment) - if int(curr.offset) < size { - curr.vv.TrimFront(size - int(curr.offset)) - } else if int(curr.offset) > size { - return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) - } - size += curr.vv.Size() - views = append(views, curr.vv.Views()...) - } - return buffer.NewVectorisedView(size, views), nil -} diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index c75ca7d71..1af87d713 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -46,9 +46,17 @@ const ( ) var ( - // ErrInvalidArgs indicates to the caller that that an invalid argument was + // ErrInvalidArgs indicates to the caller that an invalid argument was // provided. ErrInvalidArgs = errors.New("invalid args") + + // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps + // with another one. + ErrFragmentOverlap = errors.New("overlapping fragments") + + // ErrFragmentConflict indicates that, during reassembly, some fragments are + // in conflict with one another. + ErrFragmentConflict = errors.New("conflicting fragments") ) // FragmentID is the identifier for a fragment. diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 19f4920b3..9b20bb1d8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -15,9 +15,8 @@ package fragmentation import ( - "container/heap" - "fmt" "math" + "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -26,9 +25,11 @@ import ( ) type hole struct { - first uint16 - last uint16 - deleted bool + first uint16 + last uint16 + filled bool + final bool + data buffer.View } type reassembler struct { @@ -38,8 +39,7 @@ type reassembler struct { proto uint8 mu sync.Mutex holes []hole - deleted int - heap fragHeap + filled int done bool creationTime int64 pkt *stack.PacketBuffer @@ -48,49 +48,94 @@ type reassembler struct { func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, - holes: make([]hole, 0, 16), - heap: make(fragHeap, 0, 8), creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ - first: 0, - last: math.MaxUint16, - deleted: false}) + first: 0, + last: math.MaxUint16, + filled: false, + final: true, + }) return r } -// updateHoles updates the list of holes for an incoming fragment and -// returns true iff the fragment filled at least part of an existing hole. -func (r *reassembler) updateHoles(first, last uint16, more bool) bool { - used := false - for i := range r.holes { - if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first { - continue - } - used = true - r.deleted++ - r.holes[i].deleted = true - if first > r.holes[i].first { - r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false}) - } - if last < r.holes[i].last && more { - r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false}) - } - } - return used -} - func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() - consumed := 0 if r.done { // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, consumed, nil + return buffer.VectorisedView{}, 0, false, 0, nil } - if r.updateHoles(first, last, more) { + + var holeFound bool + var consumed int + for i := range r.holes { + currentHole := &r.holes[i] + + if last < currentHole.first || currentHole.last < first { + continue + } + // For IPv6, overlaps with an existing fragment are explicitly forbidden by + // RFC 8200 section 4.5: + // If any of the fragments being reassembled overlap with any other + // fragments being reassembled for the same packet, reassembly of that + // packet must be abandoned and all the fragments that have been received + // for that packet must be discarded, and no ICMP error messages should be + // sent. + // + // It is not explicitly forbidden for IPv4, but to keep parity with Linux we + // disallow it as well: + // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 + if first < currentHole.first || currentHole.last < last { + // Incoming fragment only partially fits in the free hole. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + } + if !more { + if !currentHole.final || currentHole.filled && currentHole.last != last { + // We have another final fragment, which does not perfectly overlap. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + } + } + + holeFound = true + if currentHole.filled { + // Incoming fragment is a duplicate. + continue + } + + // We are populating the current hole with the payload and creating a new + // hole for any unfilled ranges on either end. + if first > currentHole.first { + r.holes = append(r.holes, hole{ + first: currentHole.first, + last: first - 1, + filled: false, + final: false, + }) + } + if last < currentHole.last && more { + r.holes = append(r.holes, hole{ + first: last + 1, + last: currentHole.last, + filled: false, + final: currentHole.final, + }) + currentHole.final = false + } + v := pkt.Data.ToOwnedView() + consumed = v.Size() + r.size += consumed + // Update the current hole to precisely match the incoming fragment. + r.holes[i] = hole{ + first: first, + last: last, + filled: true, + final: currentHole.final, + data: v, + } + r.filled++ // For IPv6, it is possible to have different Protocol values between // fragments of a packet (because, unlike IPv4, the Protocol is not used to // identify a fragment). In this case, only the Protocol of the first @@ -103,21 +148,30 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s r.pkt = pkt r.proto = proto } - vv := pkt.Data - // We store the incoming packet only if it filled some holes. - heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) - consumed = vv.Size() - r.size += consumed + + break + } + if !holeFound { + // Incoming fragment is beyond end. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict } - // Check if all the holes have been deleted and we are ready to reassamble. - if r.deleted < len(r.holes) { + + // Check if all the holes have been filled and we are ready to reassemble. + if r.filled < len(r.holes) { return buffer.VectorisedView{}, 0, false, consumed, nil } - res, err := r.heap.reassemble() - if err != nil { - return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err) + + sort.Slice(r.holes, func(i, j int) bool { + return r.holes[i].first < r.holes[j].first + }) + + var size int + views := make([]buffer.View, 0, len(r.holes)) + for _, hole := range r.holes { + views = append(views, hole.data) + size += hole.data.Size() } - return res, r.proto, true, consumed, nil + return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index a0a04a027..2ff03eeeb 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -16,92 +16,175 @@ package fragmentation import ( "math" - "reflect" "testing" + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -type updateHolesInput struct { - first uint16 - last uint16 - more bool +type processParams struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + wantDone bool + wantError error } -var holesTestCases = []struct { - comment string - in []updateHolesInput - want []hole -}{ - { - comment: "No fragments. Expected holes: {[0 -> inf]}.", - in: []updateHolesInput{}, - want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, - }, - { - comment: "One fragment at beginning. Expected holes: {[2, inf]}.", - in: []updateHolesInput{{first: 0, last: 1, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: false}, +func TestReassemblerProcess(t *testing.T) { + const proto = 99 + + v := func(size int) buffer.View { + payload := buffer.NewView(size) + for i := 1; i < size; i++ { + payload[i] = uint8(i) * 3 + } + return payload + } + + pkt := func(size int) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v(size).ToVectorisedView(), + }) + } + + var tests = []struct { + name string + params []processParams + want []hole + }{ + { + name: "No fragments", + params: nil, + want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, }, - }, - { - comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", - in: []updateHolesInput{{first: 1, last: 2, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - {first: 3, last: math.MaxUint16, deleted: false}, + { + name: "One fragment at beginning", + params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: math.MaxUint16, filled: false, final: true}, + }, }, - }, - { - comment: "One fragment at the end. Expected holes: {[0, 0]}.", - in: []updateHolesInput{{first: 1, last: 2, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, + { + name: "One fragment in the middle", + params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 0, last: 0, filled: false, final: false}, + {first: 3, last: math.MaxUint16, filled: false, final: true}, + }, }, - }, - { - comment: "One fragment completing a packet. Expected holes: {}.", - in: []updateHolesInput{{first: 0, last: 1, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, + { + name: "One fragment at the end", + params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: true, data: v(2)}, + {first: 0, last: 0, filled: false}, + }, }, - }, - { - comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 1, more: true}, - {first: 2, last: 3, more: false}, + { + name: "One fragment completing a packet", + params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: true, data: v(2)}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: true}, + { + name: "Two fragments completing a packet", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, }, - }, - { - comment: "Two overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 2, more: true}, - {first: 2, last: 3, more: false}, + { + name: "Two fragments completing a packet with a duplicate", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 3, last: math.MaxUint16, deleted: true}, + { + name: "Two fragments completing a packet with a partial duplicate", + params: []processParams{ + {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, + {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 3, filled: true, final: false, data: v(4)}, + {first: 4, last: 5, filled: true, final: true, data: v(2)}, + }, }, - }, -} + { + name: "Two overlapping fragments", + params: []processParams{ + {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, + {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, + }, + want: []hole{ + {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 11, last: math.MaxUint16, filled: false, final: true}, + }, + }, + { + name: "Two final fragments with different ends", + params: []processParams{ + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 0, last: 9, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate, with different ends", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, + }, + } -func TestUpdateHoles(t *testing.T) { - for _, c := range holesTestCases { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - for _, i := range c.in { - r.updateHoles(i.first, i.last, i.more) - } - if !reflect.DeepEqual(r.holes, c.want) { - t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + for _, param := range test.params { + _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if done != param.wantDone || err != param.wantError { + t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) + } + } + if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD new file mode 100644 index 000000000..ca1247c1e --- /dev/null +++ b/pkg/tcpip/network/ip/BUILD @@ -0,0 +1,26 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ip", + srcs = ["generic_multicast_protocol.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + ], +) + +go_test( + name = "ip_test", + size = "small", + srcs = ["generic_multicast_protocol_test.go"], + deps = [ + ":ip", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/faketime", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go new file mode 100644 index 000000000..f2f0e069c --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -0,0 +1,676 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ip holds IPv4/IPv6 common utilities. +package ip + +import ( + "fmt" + "math/rand" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// hostState is the state a host may be in for a multicast group. +type hostState int + +// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 +// (RFC 2710 section 5). Even though the states are generic across both IGMPv2 +// and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send inital fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. +const ( + // nonMember is the "'Non-Member' state, when the host does not belong to the + // group on the interface. This is the initial state for all memberships on + // all network interfaces; it requires no storage in the host." + // + // 'Non-Listener' is the MLDv1 term used to describe this state. + // + // This state is used to keep track of groups that have been joined locally, + // but without advertising the membership to the network. + nonMember hostState = iota + + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + + // delayingMember is the "'Delaying Member' state, when the host belongs to + // the group on the interface and has a report delay timer running for that + // membership." + // + // 'Delaying Listener' is the MLDv1 term used to describe this state. + delayingMember + + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + + // idleMember is the "Idle Member" state, when the host belongs to the group + // on the interface and does not have a report delay timer running for that + // membership. + // + // 'Idle Listener' is the MLDv1 term used to describe this state. + idleMember +) + +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + +// multicastGroupState holds the Generic Multicast Protocol state for a +// multicast group. +type multicastGroupState struct { + // joins is the number of times the group has been joined. + joins uint64 + + // state holds the host's state for the group. + state hostState + + // lastToSendReport is true if we sent the last report for the group. It is + // used to track whether there are other hosts on the subnet that are also + // members of the group. + // + // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page + // 8 for MLDv1. + lastToSendReport bool + + // delayedReportJob is used to delay sending responses to membership report + // messages in order to reduce duplicate reports from multiple hosts on the + // interface. + // + // Must not be nil. + delayedReportJob *tcpip.Job +} + +// GenericMulticastProtocolOptions holds options for the generic multicast +// protocol. +type GenericMulticastProtocolOptions struct { + // Rand is the source of random numbers. + Rand *rand.Rand + + // Clock is the clock used to create timers. + Clock tcpip.Clock + + // Protocol is the implementation of the variant of multicast group protocol + // in use. + Protocol MulticastGroupProtocol + + // MaxUnsolicitedReportDelay is the maximum amount of time to wait between + // transmitting unsolicited reports. + // + // Unsolicited reports are transmitted when a group is newly joined. + MaxUnsolicitedReportDelay time.Duration + + // AllNodesAddress is a multicast address that all nodes on a network should + // be a member of. + // + // This address will not have the generic multicast protocol performed on it; + // it will be left in the non member/listener state, and packets will never + // be sent for it. + AllNodesAddress tcpip.Address +} + +// MulticastGroupProtocol is a multicast group protocol whose core state machine +// can be represented by GenericMulticastProtocolState. +type MulticastGroupProtocol interface { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled() bool + + // SendReport sends a multicast report for the specified group address. + // + // Returns false if the caller should queue the report to be sent later. Note, + // returning false does not mean that the receiver hit an error. + SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) + + // SendLeave sends a multicast leave for the specified group address. + SendLeave(groupAddress tcpip.Address) *tcpip.Error +} + +// GenericMulticastProtocolState is the per interface generic multicast protocol +// state. +// +// There is actually no protocol named "Generic Multicast Protocol". Instead, +// the term used to refer to a generic multicast protocol that applies to both +// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state +// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. +// +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// +// GenericMulticastProtocolState.Init MUST be called before calling any of +// the methods on GenericMulticastProtocolState. +// +// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the +// multicast group protocol is disabled so that leave messages may be sent. +type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + + opts GenericMulticastProtocolOptions + + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState + + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex +} + +// Init initializes the Generic Multicast Protocol state. +// +// Must only be called once for the lifetime of g; Init will panic if it is +// called twice. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + + *g = GenericMulticastProtocolState{ + opts: opts, + memberships: make(map[tcpip.Address]multicastGroupState), + protocolMU: protocolMU, + } +} + +// MakeAllNonMemberLocked transitions all groups to the non-member state. +// +// The groups will still be considered joined locally. +// +// MUST be called when the multicast group protocol is disabled. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { + if !g.opts.Protocol.Enabled() { + return + } + + for groupAddress, info := range g.memberships { + g.transitionToNonMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. +// +// Must only be called after calling MakeAllNonMember as a group should not be +// initialized while it is not in the non-member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { + if !g.opts.Protocol.Enabled() { + return + } + + for groupAddress, info := range g.memberships { + g.initializeNewMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} + +// JoinGroupLocked handles joining a new group. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) { + if info, ok := g.memberships[groupAddress]; ok { + // The group has already been joined. + info.joins++ + g.memberships[groupAddress] = info + return + } + + info := multicastGroupState{ + // Since we just joined the group, its count is 1. + joins: 1, + // The state will be updated below, if required. + state: nonMember, + lastToSendReport: false, + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + if !g.opts.Protocol.Enabled() { + panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress)) + } + + info, ok := g.memberships[groupAddress] + if !ok { + panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) + } + + g.maybeSendDelayedReportLocked(groupAddress, &info) + g.memberships[groupAddress] = info + }), + } + + if g.opts.Protocol.Enabled() { + g.initializeNewMemberLocked(groupAddress, &info) + } + + g.memberships[groupAddress] = info +} + +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] + return ok +} + +// LeaveGroupLocked handles leaving the group. +// +// Returns false if the group is not currently joined. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] + if !ok { + return false + } + + if info.joins == 0 { + panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress)) + } + info.joins-- + if info.joins != 0 { + // If we still have outstanding joins, then do nothing further. + g.memberships[groupAddress] = info + return true + } + + g.transitionToNonMemberLocked(groupAddress, &info) + delete(g.memberships, groupAddress) + return true +} + +// HandleQueryLocked handles a query message with the specified maximum response +// time. +// +// If the group address is unspecified, then reports will be scheduled for all +// joined groups. +// +// Report(s) will be scheduled to be sent after a random duration between 0 and +// the maximum response time. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Protocol.Enabled() { + return + } + + // As per RFC 2236 section 2.4 (for IGMPv2), + // + // In a Membership Query message, the group address field is set to zero + // when sending a General Query, and set to the group address being + // queried when sending a Group-Specific Query. + // + // As per RFC 2710 section 3.6 (for MLDv1), + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + if groupAddress.Unspecified() { + // This is a general query as the group address is unspecified. + for groupAddress, info := range g.memberships { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } + } else if info, ok := g.memberships[groupAddress]; ok { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } +} + +// HandleReportLocked handles a report message. +// +// If the report is for a joined group, any active delayed report will be +// cancelled and the host state for the group transitions to idle. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { + if !g.opts.Protocol.Enabled() { + return + } + + // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), + // + // If the host receives another host's Report (version 1 or 2) while it has + // a timer running, it stops its timer for the specified group and does not + // send a Report + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // If a node receives another node's Report from an interface for a + // multicast address while it has a timer running for that same address + // on that interface, it stops its timer and does not send a Report for + // that address, thus suppressing duplicate reports on the link. + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { + info.delayedReportJob.Cancel() + info.lastToSendReport = false + info.state = idleMember + g.memberships[groupAddress] = info + } +} + +// initializeNewMemberLocked initializes a new group membership. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != nonMember { + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) + } + + info.lastToSendReport = false + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + info.state = idleMember + return + } + + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a host joins a multicast group, it should immediately transmit an + // unsolicited Version 2 Membership Report for that group" ... "it is + // recommended that it be repeated". + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node starts listening to a multicast address on an interface, + // it should immediately transmit an unsolicited Report for that address + // on that interface, in case it is the first listener on the link. To + // cover the possibility of the initial Report being lost or damaged, it + // is recommended that it be repeated once or twice after short delays + // [Unsolicited Report Interval]. + // + // TODO(gvisor.dev/issue/4901): Support a configurable number of initial + // unsolicited reports. + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. +// +// Host must be in pending, delaying or queued delaying member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if !info.state.isDelayingMember() { + panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) + } + + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + info.state = idleMember + } else { + info.state = queuedDelayingMember + } +} + +// maybeSendLeave attempts to send a leave message. +func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { + if !g.opts.Protocol.Enabled() || !lastToSendReport { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // Okay to ignore the error here as if packet write failed, the multicast + // routers will eventually drop our membership anyways. If the interface is + // being disabled or removed, the generic multicast protocol's should be + // cleared eventually. + // + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a router receives a Report, it adds the group being reported to + // the list of multicast group memberships on the network on which it + // received the Report and sets the timer for the membership to the + // [Group Membership Interval]. Repeated Reports refresh the timer. If + // no Reports are received for a particular group before this timer has + // expired, the router assumes that the group has no local members and + // that it need not forward remotely-originated multicasts for that + // group onto the attached network. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // When a router receives a Report from a link, if the reported address + // is not already present in the router's list of multicast address + // having listeners on that link, the reported address is added to the + // list, its timer is set to [Multicast Listener Interval], and its + // appearance is made known to the router's multicast routing component. + // If a Report is received for a multicast address that is already + // present in the router's list, the timer for that address is reset to + // [Multicast Listener Interval]. If an address's timer expires, it is + // assumed that there are no longer any listeners for that address + // present on the link, so it is deleted from the list and its + // disappearance is made known to the multicast routing component. + // + // The requirement to send a leave message is also optional (it MAY be + // skipped): + // + // As per RFC 2236 section 6 page 8 (for IGMPv2), + // + // "send leave" for the group on the interface. If the interface + // state says the Querier is running IGMPv1, this action SHOULD be + // skipped. If the flag saying we were the last host to report is + // cleared, this action MAY be skipped. The Leave Message is sent to + // the ALL-ROUTERS group (224.0.0.2). + // + // As per RFC 2710 section 5 page 8 (for MLDv1), + // + // "send done" for the address on the interface. If the flag saying + // we were the last node to report is cleared, this action MAY be + // skipped. The Done message is sent to the link-scope all-routers + // address (FF02::2). + _ = g.opts.Protocol.SendLeave(groupAddress) +} + +// transitionToNonMemberLocked transitions the given multicast group the the +// non-member/listener state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state == nonMember { + return + } + + info.delayedReportJob.Cancel() + g.maybeSendLeave(groupAddress, info.lastToSendReport) + info.lastToSendReport = false + info.state = nonMember +} + +// setDelayTimerForAddressRLocked sets timer to send a delay report. +// +// Precondition: g.protocolMU MUST be read locked. +func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { + if info.state == nonMember { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // If a timer for the group is already unning, it is reset to the random + // value only if the requested Max Response Time is less than the remaining + // value of the running timer. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // If a timer for any address is already running, it is reset to the new + // random value only if the requested Maximum Response Delay is less than + // the remaining value of the running timer. + if info.state == delayingMember { + // TODO: Reset the timer if time remaining is greater than maxResponseTime. + return + } + + info.state = delayingMember + info.delayedReportJob.Cancel() + info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) +} + +// calculateDelayTimerDuration returns a random time between (0, maxRespTime]. +func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration { + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // When a host receives a Group-Specific Query, it sets a delay timer to a + // random value selected from the range (0, Max Response Time]... + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node receives a Multicast-Address-Specific Query, if it is + // listening to the queried Multicast Address on the interface from + // which the Query was received, it sets a delay timer for that address + // to a random value selected from the range [0, Maximum Response Delay], + // as above. + if maxRespTime == 0 { + return 0 + } + return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime))) +} diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go new file mode 100644 index 000000000..f56f7aa90 --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -0,0 +1,877 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "math/rand" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" +) + +const ( + addr1 = tcpip.Address("\x01") + addr2 = tcpip.Address("\x02") + addr3 = tcpip.Address("\x03") + addr4 = tcpip.Address("\x04") + + maxUnsolicitedReportDelay = time.Second +) + +var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) + +type mockMulticastGroupProtocol struct { + t *testing.T + + mu sync.RWMutex + + // Must only be accessed with mu held. + sendReportGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + sendLeaveGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + makeQueuePackets bool + + // Must only be accessed with mu held. + disabled bool +} + +func (m *mockMulticastGroupProtocol) init() { + m.mu.Lock() + defer m.mu.Unlock() + m.initLocked() +} + +func (m *mockMulticastGroupProtocol) initLocked() { + m.sendReportGroupAddrCount = make(map[tcpip.Address]int) + m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +} + +func (m *mockMulticastGroupProtocol) setEnabled(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.disabled = !v +} + +// Enabled implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be read locked. +func (m *mockMulticastGroupProtocol) Enabled() bool { + return !m.disabled +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + + m.sendReportGroupAddrCount[groupAddress]++ + return !m.makeQueuePackets, nil +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. +func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + + m.sendLeaveGroupAddrCount[groupAddress]++ + return nil +} + +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendReportGroupAddresses { + sendReportGroupAddrCount[a] = 1 + } + + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendLeaveGroupAddresses { + sendLeaveGroupAddrCount[a] = 1 + } + + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + switch p.Last().String() { + case ".mu", ".t", ".makeQueuePackets", ".disabled": + return true + } + return false + }, + cmp.Ignore(), + ), + ) + m.initLocked() + return diff +} + +func TestJoinGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendReports bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendReports: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendReports: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(0)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + // Joining a group should send a report immediately and another after + // a random interval between 0 and the maximum unsolicited report delay. + mgp.mu.Lock() + g.JoinGroupLocked(test.addr) + mgp.mu.Unlock() + if test.shouldSendReports { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLeaveGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendMessages bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendMessages: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendMessages: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(1)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(test.addr) + mgp.mu.Unlock() + if test.shouldSendMessages { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Leaving a group should send a leave report immediately and cancel any + // delayed reports. + { + mgp.mu.Lock() + res := g.LeaveGroupLocked(test.addr) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr) + } + } + if test.shouldSendMessages { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleReport(t *testing.T) { + tests := []struct { + name string + reportAddr tcpip.Address + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + reportAddr: "", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + reportAddr: "\x00", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + reportAddr: addr1, + expectReportsFor: []tcpip.Address{addr2}, + }, + { + name: "Specified all-nodes", + reportAddr: addr3, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified other", + reportAddr: addr4, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(2)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(addr1) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr2) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr3) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report for a group we have a timer scheduled for should + // cancel our delayed report timer for the group. + mgp.mu.Lock() + g.HandleReportLocked(test.reportAddr) + mgp.mu.Unlock() + if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleQuery(t *testing.T) { + tests := []struct { + name string + queryAddr tcpip.Address + maxDelay time.Duration + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + queryAddr: "", + maxDelay: 0, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + queryAddr: "\x00", + maxDelay: 1, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + queryAddr: addr1, + maxDelay: 2, + expectReportsFor: []tcpip.Address{addr1}, + }, + { + name: "Specified all-nodes", + queryAddr: addr3, + maxDelay: 3, + expectReportsFor: nil, + }, + { + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectReportsFor: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(addr1) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr2) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr3) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should make us schedule a new delayed report if it + // is a query directed at us or a general query. + mgp.mu.Lock() + g.HandleQueryLocked(test.queryAddr, test.maxDelay) + mgp.mu.Unlock() + if len(test.expectReportsFor) != 0 { + clock.Advance(test.maxDelay) + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestJoinCount(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: time.Second, + }) + + // Set the join count to 2 for a group. + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + // Only the first join should trigger a report to be sent. + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should still be considered joined after leaving once. + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if !isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + // A leave report should only be sent once the join count reaches 0. + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Leaving once more should actually remove us from the group. + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should no longer be joined so we should not have anything to + // leave. + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestMakeAllNonMemberAndInitialize(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(addr1) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr2) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr3) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should send the leave reports for each but still consider them locally + // joined. + mgp.mu.Lock() + g.MakeAllNonMemberLocked() + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + for _, group := range []tcpip.Address{addr1, addr2, addr3} { + mgp.mu.RLock() + res := g.IsLocallyJoinedRLocked(group) + mgp.mu.RUnlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group) + } + } + + // Should send the initial set of unsolcited reports. + mgp.mu.Lock() + g.InitializeGroupsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +// TestGroupStateNonMember tests that groups do not send packets when in the +// non-member state, but are still considered locally joined. +func TestGroupStateNonMember(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + mgp.setEnabled(false) + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining groups should not send any reports. + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + { + mgp.mu.Lock() + g.JoinGroupLocked(addr2) + res := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should not send any reports. + mgp.mu.Lock() + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Nanosecond) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Leaving groups should not send any leave messages. + { + mgp.mu.Lock() + addr2LeaveRes := g.LeaveGroupLocked(addr2) + addr1IsJoined := g.IsLocallyJoinedRLocked(addr1) + addr2IsJoined := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !addr2LeaveRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2) + } + if !addr1IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + if addr2IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestQueuedPackets(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr1) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.mu.Lock() + g.HandleReportLocked(addr1) + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr2) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.HandleReportLocked(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d49c44846..3005973d7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -193,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu panic("not implemented") } -func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") @@ -207,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net panic("not implemented") } -func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -223,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } -func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -348,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -554,7 +550,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -623,11 +619,11 @@ func TestReceive(t *testing.T) { view := buffer.NewView(header.IPv6MinimumSize + payloadLen) ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadLen, - NextHeader: 10, - HopLimit: ipv6.DefaultTTL, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, + PayloadLength: payloadLen, + TransportProtocol: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, }) // Make payload be non-zero. @@ -937,7 +933,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -997,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) { // Create the outer IPv6 header. ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 20, + SrcAddr: outerSrcAddr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -1011,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp.SetIdent(0xdead) icmp.SetSequence(0xbeef) - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - }) - + var extHdrs header.IPv6ExtHdrSerializer // Build the fragmentation header if needed. if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ FragmentOffset: *c.fragmentOffset, M: true, Identification: 0x12345678, }) } + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, + ExtensionHeaders: extHdrs, + }) + // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { view[i] = uint8(i) @@ -1093,7 +1088,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { dataBuf := [dataLen]byte{1, 2, 3, 4} data := dataBuf[:] - ipv4Options := header.IPv4Options{0, 1, 0, 1} + ipv4Options := header.IPv4OptionsSerializer{ + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + } + + expectOptions := header.IPv4Options{ + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + } ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] @@ -1243,7 +1250,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding() + ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1266,7 +1273,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { netHdr := pkt.NetworkHeader() - hdrLen := header.IPv4MinimumSize + len(ipv4Options) + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) if len(netHdr.View()) != hdrLen { t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } @@ -1276,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { checker.DstAddr(remoteIPv4Addr), checker.IPv4HeaderLength(hdrLen), checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(ipv4Options), + checker.IPv4Options(expectOptions), checker.IPPayload(data), ) }, @@ -1288,7 +1295,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding())) + ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) ip.Encode(&header.IPv4Fields{ Protocol: transportProto, TTL: ipv4.DefaultTTL, @@ -1307,7 +1314,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { netHdr := pkt.NetworkHeader() - hdrLen := header.IPv4MinimumSize + len(ipv4Options) + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) if len(netHdr.View()) != hdrLen { t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } @@ -1317,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { checker.DstAddr(remoteIPv4Addr), checker.IPv4HeaderLength(hdrLen), checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(ipv4Options), + checker.IPv4Options(expectOptions), checker.IPPayload(data), ) }, @@ -1336,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1379,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + // NB: we're lying about transport protocol here to verify the raw + // fragment header bytes. + TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1414,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return buffer.View(ip).ToVectorisedView() }, @@ -1449,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 6252614ec..32f53f217 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -6,6 +6,7 @@ go_library( name = "ipv4", srcs = [ "icmp.go", + "igmp.go", "ipv4.go", ], visibility = ["//visibility:public"], @@ -17,6 +18,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -24,7 +26,10 @@ go_library( go_test( name = "ipv4_test", size = "small", - srcs = ["ipv4_test.go"], + srcs = [ + "igmp_test.go", + "ipv4_test.go", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 488945226..8e392f86c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { stats := e.protocol.stack.Stats() - received := stats.ICMP.V4PacketsReceived + received := stats.ICMP.V4.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.Echo.Increment() - sent := stats.ICMP.V4PacketsSent + sent := stats.ICMP.V4.PacketsSent if !e.protocol.stack.AllowICMPMessage() { sent.RateLimited.Increment() return @@ -379,7 +379,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - sent := p.stack.Stats().ICMP.V4PacketsSent + sent := p.stack.Stats().ICMP.V4.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go new file mode 100644 index 000000000..da88d65d1 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -0,0 +1,345 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "fmt" + "sync/atomic" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // igmpV1PresentDefault is the initial state for igmpV1Present in the + // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is + // the initial state." + igmpV1PresentDefault = 0 + + // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18 + // See note on igmpState.igmpV1Present for more detail. + v1RouterPresentTimeout = 400 * time.Second + + // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router + // will send General Queries with the Max Response Time set to 0. This MUST + // be interpreted as a value of 100 (10 seconds)." + // + // Note that the Max Response Time field is a value in units of deciseconds. + v1MaxRespTime = 10 * time.Second + + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited IGMP reports. + // + // Obtained from RFC 2236 Section 8.10, Page 19. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// IGMPOptions holds options for IGMP. +type IGMPOptions struct { + // Enabled indicates whether IGMP will be performed. + // + // When enabled, IGMP may transmit IGMP report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // IGMP packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*igmpState)(nil) + +// igmpState is the per-interface IGMP state. +// +// igmpState.init() MUST be called after creating an IGMP state. +type igmpState struct { + // The IPv4 endpoint this igmpState is for. + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState + + // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from + // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 + // Membership Reports in response to its Queries, and will not pay + // attention to Version 2 Membership Reports. Therefore, a state variable + // MUST be kept for each interface, describing whether the multicast + // Querier on that interface is running IGMPv1 or IGMPv2. This variable + // MUST be based upon whether or not an IGMPv1 query was heard in the last + // [Version 1 Router Present Timeout] seconds". + // + // Must be accessed with atomic operations. Holds a value of 1 when true, 0 + // when false. + igmpV1Present uint32 + + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job +} + +// Enabled implements ip.MulticastGroupProtocol. +func (igmp *igmpState) Enabled() bool { + // No need to perform IGMP on loopback interfaces since they don't have + // neighbouring nodes. + return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled() +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + igmpType := header.IGMPv2MembershipReport + if igmp.v1Present() { + igmpType = header.IGMPv1MembershipReport + } + return igmp.writePacket(groupAddress, groupAddress, igmpType) +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + // As per RFC 2236 Section 6, Page 8: "If the interface state says the + // Querier is running IGMPv1, this action SHOULD be skipped. If the flag + // saying we were the last host to report is cleared, this action MAY be + // skipped." + if igmp.v1Present() { + return nil + } + _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + return err +} + +// init sets up an igmpState struct, and is required to be called before using +// a new igmpState. +// +// Must only be called once for the lifetime of igmp. +func (igmp *igmpState) init(ep *endpoint) { + igmp.ep = ep + igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: igmp, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv4AllSystems, + }) + igmp.igmpV1Present = igmpV1PresentDefault + igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { + igmp.setV1Present(false) + }) +} + +// handleIGMP handles an IGMP packet. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { + stats := igmp.ep.protocol.stack.Stats() + received := stats.IGMP.PacketsReceived + headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.IGMP(headerView) + + // Temporarily reset the checksum field to 0 in order to calculate the proper + // checksum. + wantChecksum := h.Checksum() + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(wantChecksum) + + if gotChecksum != wantChecksum { + received.ChecksumErrors.Increment() + return + } + + switch h.Type() { + case header.IGMPMembershipQuery: + received.MembershipQuery.Increment() + if len(headerView) < header.IGMPQueryMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) + case header.IGMPv1MembershipReport: + received.V1MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPv2MembershipReport: + received.V2MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPLeaveGroup: + received.LeaveGroup.Increment() + // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or + // Report, are ignored in all states" + + default: + // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should + // be silently ignored. New message types may be used by newer versions of + // IGMP, by multicast routing protocols, or other uses." + received.Unrecognized.Increment() + } +} + +func (igmp *igmpState) v1Present() bool { + return atomic.LoadUint32(&igmp.igmpV1Present) == 1 +} + +func (igmp *igmpState) setV1Present(v bool) { + if v { + atomic.StoreUint32(&igmp.igmpV1Present, 1) + } else { + atomic.StoreUint32(&igmp.igmpV1Present, 0) + } +} + +// handleMembershipQuery handles a membership query. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { + // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero + // then change the state to note that an IGMPv1 router is present and + // schedule the query received Job. + if maxRespTime == 0 && igmp.Enabled() { + igmp.igmpV1Job.Cancel() + igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) + igmp.setV1Present(true) + maxRespTime = v1MaxRespTime + } + + igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime) +} + +// handleMembershipReport handles a membership report. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { + igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) +} + +// writePacket assembles and sends an IGMP packet. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { + igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) + igmpData.SetType(igmpType) + igmpData.SetGroupAddress(groupAddress) + igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()), + Data: buffer.View(igmpData).ToVectorisedView(), + }) + + addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) + if addressEndpoint == nil { + return false, nil + } + localAddr := addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() + addressEndpoint = nil + igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.IGMPProtocolNumber, + TTL: header.IGMPTTL, + TOS: stack.DefaultTOS, + }, header.IPv4OptionsSerializer{ + &header.IPv4SerializableRouterAlertOption{}, + }) + + sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return false, err + } + switch igmpType { + case header.IGMPv1MembershipReport: + sentStats.V1MembershipReport.Increment() + case header.IGMPv2MembershipReport: + sentStats.V2MembershipReport.Increment() + case header.IGMPLeaveGroup: + sentStats.LeaveGroup.Increment() + default: + panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) + } + return true, nil +} + +// joinGroup handles adding a new group to the membership map, setting up the +// IGMP state for the group, and sending and scheduling the required +// messages. +// +// If the group already exists in the membership map, returns +// tcpip.ErrDuplicateAddress. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress) +} + +// isInGroup returns true if the specified group has been joined locally. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { + return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Leave Group message +// if required. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of IGMP, but remains +// joined locally. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) softLeaveAll() { + igmp.genericMulticastProtocol.MakeAllNonMemberLocked() +} + +// initializeAll attemps to initialize the IGMP state for each group that has +// been joined locally. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) initializeAll() { + igmp.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) sendQueuedReports() { + igmp.genericMulticastProtocol.SendQueuedReportsLocked() +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go new file mode 100644 index 000000000..1ee573ac8 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -0,0 +1,215 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + addr = tcpip.Address("\x0a\x00\x00\x01") + multicastAddr = tcpip.Address("\xe0\x00\x00\x03") + nicID = 1 +) + +// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. Raises a t.Error if +// any field does not match. +func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.SrcAddr(addr), + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(igmpType), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + // Create an endpoint of queue size 1, since no more than 1 packets are ever + // queued in the tests in this file. + e := channel.New(1, 1280, linkAddr) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + return e, s, clock +} + +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: 1, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(igmpType) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is +// present on the network. The IGMP stack will then send IGMPv1 Membership +// reports for backwards compatibility. +func TestIgmpV1Present(t *testing.T) { + e, s, clock := createStack(t, true) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // This NIC will send an IGMPv2 report immediately, before this test can get + // the IGMPv1 General Membership Query in. + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject an IGMPv1 General Membership Query which is identical to a standard + // membership query except the Max Response Time is set to 0, which will tell + // the stack that this is a router using IGMPv1. Send it to the all systems + // group which is the only group this host belongs to. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems) + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { + t.Fatalf("got Membership Queries received = %d, want = 1", got) + } + + // Before advancing the clock, verify that this host has not sent a + // V1MembershipReport yet. + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) + } + + // Verify the solicited Membership Report is sent. Now that this NIC has seen + // an IGMPv1 query, it should send an IGMPv1 Membership Report. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) +} + +func TestSendQueuedIGMPReports(t *testing.T) { + e, s, clock := createStack(t, true) + + // Joining a group without an assigned address should queue IGMP packets; none + // should be sent without an assigned address. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) + } + reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport + if got := reportStat.Value(); got != 0 { + t.Errorf("got reportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } + + // The initial set of IGMP reports that were queued should be sent once an + // address is assigned. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + if got := reportStat.Value(); got != 1 { + t.Errorf("got reportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + if got := reportStat.Value(); got != 2 { + t.Errorf("got reportStat.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + + // Should have no more packets to send after the initial set of unsolicited + // reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1efe6297a..e9ff70d04 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -83,6 +83,7 @@ type endpoint struct { sync.RWMutex addressableEndpointState stack.AddressableEndpointState + igmp igmpState } } @@ -93,7 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) + e.mu.igmp.init(e) + e.mu.Unlock() return e } @@ -121,11 +125,22 @@ func (e *endpoint) Enable() *tcpip.Error { // We have no need for the address endpoint. ep.DecRef() + // Groups may have been joined while the endpoint was disabled, or the + // endpoint may have left groups from the perspective of IGMP when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.mu.igmp.initializeAll() + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. - _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) - return err + if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err)) + } + + return nil } // Enabled implements stack.NetworkEndpoint. @@ -157,19 +172,27 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.isEnabled() { return } // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } + // Leave groups from the perspective of IGMP so that routers know that + // we are no longer interested in the group. + e.mu.igmp.softLeaveAll() + // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // DefaultTTL is the default time-to-live value for this endpoint. @@ -198,37 +221,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) { hdrLen := header.IPv4MinimumSize - var opts header.IPv4Options - if params.Options != nil { - var ok bool - if opts, ok = params.Options.(header.IPv4Options); !ok { - panic(fmt.Sprintf("want IPv4Options, got %T", params.Options)) - } - hdrLen += opts.SizeWithPadding() - if hdrLen > header.IPv4MaximumHeaderSize { - // Since we have no way to report an error we must either panic or create - // a packet which is different to what was requested. Choose panic as this - // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize)) - } + var optLen int + if options != nil { + optLen = int(options.Length()) + } + hdrLen += optLen + if hdrLen > header.IPv4MaximumHeaderSize { + // Since we have no way to report an error we must either panic or create + // a packet which is different to what was requested. Choose panic as this + // would be a programming error that should be caught in testing. + panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize)) } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. - id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ TotalLength: length, ID: uint16(id), TTL: params.TTL, TOS: params.TOS, Protocol: uint8(params.Protocol), - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - Options: opts, + SrcAddr: srcAddr, + DstAddr: dstAddr, + Options: options, }) ip.SetChecksum(^ip.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber @@ -259,7 +279,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) // iptables filtering. All packets that reach here are locally // generated. @@ -347,7 +367,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) @@ -461,7 +481,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } @@ -566,21 +586,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats.IP.MalformedPacketsReceived.Increment() return } - srcAddr := h.SourceAddress() - dstAddr := h.DestinationAddress() - - addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) - if addressEndpoint == nil { - if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - _ = e.forwardPacket(pkt) - return - } - subnet := addressEndpoint.AddressWithPrefix().Subnet() - addressEndpoint.DecRef() // There has been some confusion regarding verifying checksums. We need // just look for negative 0 (0xffff) as the checksum, as it's not possible to @@ -608,16 +613,42 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // As per RFC 1122 section 3.2.1.3: // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { + if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { stats.IP.InvalidSourceAddressesReceived.Increment() return } + // Make sure the source address is not a subnet-local broadcast address. + if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + if subnet.IsBroadcast(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() + return + } + } + + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + addressEndpoint.DecRef() + pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + } else if !e.IsInGroup(dstAddr) { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } - pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + _ = e.forwardPacket(pkt) + return + } // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. @@ -692,6 +723,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } + if p == header.IGMPProtocolNumber { + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() + return + } if opts := h.Options(); len(opts) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options @@ -747,7 +784,12 @@ func (e *endpoint) Close() { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err == nil { + e.mu.igmp.sendQueuedReports() + } + return ep, err } // RemovePermanentAddress implements stack.AddressableEndpoint. @@ -770,34 +812,26 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo defer e.mu.Unlock() loopback := e.nic.IsLoopback() - addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { + return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { subnet := addressEndpoint.Subnet() // IPv4 has a notion of a subnet broadcast address and considers the // loopback interface bound to an address's whole subnet (on linux). return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) - }) - if addressEndpoint != nil { - return addressEndpoint - } - - if !allowTemp { - return nil - } - - addr := localAddr.WithPrefix() - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) - if err != nil { - // AddAddress only returns an error if the address is already assigned, - // but we just checked above if the address exists so we expect no error. - panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) - } - return addressEndpoint + }, allowTemp, tempPEB) } // AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { e.mu.RLock() defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements +// +// Precondition: igmp.ep.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) } @@ -816,28 +850,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.mu.igmp.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mu.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mu.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -863,6 +912,8 @@ type protocol struct { hashIV uint32 fragmentation *fragmentation.Fragmentation + + options Options } // Number returns the ipv4 protocol number. @@ -987,17 +1038,23 @@ func addressToUint32(addr tcpip.Address) uint32 { return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 } -// hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number and a 32-bit number to -// generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - a := addressToUint32(r.LocalAddress) - b := addressToUint32(r.RemoteAddress) +// hashRoute calculates a hash value for the given source/destination pair using +// the addresses, transport protocol number and a 32-bit number to generate the +// hash. +func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { + a := addressToUint32(srcAddr) + b := addressToUint32(dstAddr) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -// NewProtocol returns an IPv4 network protocol. -func NewProtocol(s *stack.Stack) stack.NetworkProtocol { +// Options holds options to configure a new protocol. +type Options struct { + // IGMP holds options for IGMP. + IGMP IGMPOptions +} + +// NewProtocolWithOptions returns an IPv4 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -1007,14 +1064,22 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - p := &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + options: opts, + } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } - p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) - return p +} + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) } func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { @@ -1129,6 +1194,12 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres } pointer := tsOpt.Pointer() + // RFC 791 page 22 states: "The smallest legal value is 5." + // Since the pointer is 1 based, and the header is 4 bytes long the + // pointer must point beyond the header therefore 4 or less is bad. + if pointer <= header.IPv4OptionTimestampHdrLength { + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } // To simplify processing below, base further work on the array of timestamps // beyond the header, rather than on the whole option. Also to aid // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. @@ -1215,7 +1286,15 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength } - nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + pointer := rrOpt.Pointer() + // RFC 791 page 20 states: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // Since the pointer is 1 based, and the header is 3 bytes long the + // pointer must point beyond the header therefore 3 or less is bad. + if pointer <= header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } // RFC 791 page 21 says // If the route data area is already full (the pointer exceeds the @@ -1230,14 +1309,14 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // do this (as do most implementations). It is probable that the inclusion // of these words is a copy/paste error from the timestamp option where // there are two failure reasons given. - if nextSlot >= optlen { + if pointer > optlen { return 0, nil } // The data area isn't full but there isn't room for a new entry. // Either Length or Pointer could be bad. We must select Pointer for Linux - // compatibility, even if only the length is bad. - if nextSlot+header.IPv4AddressSize > optlen { + // compatibility, even if only the length is bad. NB. pointer is 1 based. + if pointer+header.IPv4AddressSize > optlen+1 { if false { // This is what we would do if we were not being Linux compatible. // Check for bad pointer or length value. Must be a multiple of 4 after diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 4e4e1f3b4..9e2d2cfd6 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -103,105 +103,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested -// fields when options are supplied. -func TestIPv4EncodeOptions(t *testing.T) { - tests := []struct { - name string - options header.IPv4Options - encodedOptions header.IPv4Options // reply should look like this - wantIHL int - }{ - { - name: "valid no options", - wantIHL: header.IPv4MinimumSize, - }, - { - name: "one byte options", - options: header.IPv4Options{1}, - encodedOptions: header.IPv4Options{1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "two byte options", - options: header.IPv4Options{1, 1}, - encodedOptions: header.IPv4Options{1, 1, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "three byte options", - options: header.IPv4Options{1, 1, 1}, - encodedOptions: header.IPv4Options{1, 1, 1, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "four byte options", - options: header.IPv4Options{1, 1, 1, 1}, - encodedOptions: header.IPv4Options{1, 1, 1, 1}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "five byte options", - options: header.IPv4Options{1, 1, 1, 1, 1}, - encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 8, - }, - { - name: "thirty nine byte options", - options: header.IPv4Options{ - 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, - 25, 26, 27, 28, 29, 30, 31, 32, - 33, 34, 35, 36, 37, 38, 39, - }, - encodedOptions: header.IPv4Options{ - 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, - 25, 26, 27, 28, 29, 30, 31, 32, - 33, 34, 35, 36, 37, 38, 39, 0, - }, - wantIHL: header.IPv4MinimumSize + 40, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - paddedOptionLength := test.options.SizeWithPadding() - ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength) - hdr := buffer.NewPrependable(int(totalLen)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - // To check the padding works, poison the last byte of the options space. - if paddedOptionLength != len(test.options) { - ip.SetHeaderLength(uint8(ipHeaderLength)) - ip.Options()[paddedOptionLength-1] = 0xff - ip.SetHeaderLength(0) - } - ip.Encode(&header.IPv4Fields{ - Options: test.options, - }) - options := ip.Options() - wantOptions := test.encodedOptions - if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { - t.Errorf("got IHL of %d, want %d", got, want) - } - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(wantOptions) == 0 && len(options) == 0 { - return - } - - if diff := cmp.Diff(wantOptions, options); diff != "" { - t.Errorf("options mismatch (-want +got):\n%s", diff) - } - }) - } -} - func TestForwarding(t *testing.T) { const ( nicID1 = 1 @@ -453,14 +354,6 @@ func TestIPv4Sanity(t *testing.T) { replyOptions: header.IPv4Options{1, 1, 0, 0}, }, { - name: "Check option padding", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 1}, - replyOptions: header.IPv4Options{1, 1, 1, 0}, - }, - { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, maxTotalLength: ipv4.MaxTotalSize, @@ -583,7 +476,7 @@ func TestIPv4Sanity(t *testing.T) { 68, 7, 5, 0, // ^ ^ Linux points here which is wrong. // | Not a multiple of 4 - 1, 2, 3, + 1, 2, 3, 0, }, shouldFail: true, expectErrorICMP: true, @@ -662,6 +555,56 @@ func TestIPv4Sanity(t *testing.T) { }, }, { + // Timestamp pointer uses one based counting so 0 is invalid. + name: "timestamp pointer invalid", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, 0, 0x00, + // ^ 0 instead of 5 or more. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Timestamp pointer cannot be less than 5. It must point past the header + // which is 4 bytes. (1 based counting) + name: "timestamp pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, + // ^ header is 4 bytes, so 4 should fail. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "valid timestamp pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, + // ^ header is 4 bytes, so 5 should succeed. + 0, 0, 0, 0, + }, + replyOptions: header.IPv4Options{ + 68, 8, 9, 0x00, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { // Needs 8 bytes for a type 1 timestamp but there are only 4 free. name: "bad timer element alignment", maxTotalLength: ipv4.MaxTotalSize, @@ -792,7 +735,61 @@ func TestIPv4Sanity(t *testing.T) { }, }, { - // Confirm linux bug for bug compatibility. + // Pointer uses one based counting so 0 is invalid. + name: "record route pointer zero", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, 0, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. 3 should fail. + name: "record route pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. Check 4 passes. (Duplicates "single + // record route with room") + name: "valid record route pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm Linux bug for bug compatibility. // Linux returns slot 22 but the error is in slot 21. name: "multiple record route with not enough room", maxTotalLength: ipv4.MaxTotalSize, @@ -863,8 +860,10 @@ func TestIPv4Sanity(t *testing.T) { }, }) - paddedOptionLength := test.options.SizeWithPadding() - ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength + if len(test.options)%4 != 0 { + t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) + } + ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } @@ -883,11 +882,6 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } - // To check the padding works, poison the options space. - if paddedOptionLength != len(test.options) { - ip.SetHeaderLength(uint8(ipHeaderLength)) - ip.Options()[paddedOptionLength-1] = 0x01 - } ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, @@ -895,10 +889,19 @@ func TestIPv4Sanity(t *testing.T) { TTL: test.TTL, SrcAddr: remoteIPv4Addr, DstAddr: ipv4Addr.Address, - Options: test.options, }) if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) + } else { + // Set the calculated header length, since we may manually add options. + ip.SetHeaderLength(uint8(ipHeaderLength)) + } + if len(test.options) != 0 { + // Copy options manually. We do not use Encode for options so we can + // verify malformed options with handcrafted payloads. + if want, got := copy(ip.Options(), test.options), len(test.options); want != got { + t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) + } } ip.SetChecksum(0) ipHeaderChecksum := ip.CalculateChecksum() @@ -1003,7 +1006,7 @@ func TestIPv4Sanity(t *testing.T) { } // If the IP options change size then the packet will change size, so // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - paddedOptionLength + sizeChange := len(test.replyOptions) - len(test.options) checker.IPv4(t, replyIPHeader, checker.IPv4HeaderLength(ipHeaderLength+sizeChange), @@ -2320,6 +2323,28 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, + { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { @@ -2513,7 +2538,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -2530,7 +2555,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) @@ -2644,8 +2669,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2687,8 +2712,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2736,8 +2761,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 0ac24a6fb..afa45aefe 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -8,6 +8,7 @@ go_library( "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "mld.go", "ndp.go", ], visibility = ["//visibility:public"], @@ -19,6 +20,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -49,3 +51,19 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "ipv6_x_test", + size = "small", + srcs = ["mld_test.go"], + deps = [ + ":ipv6", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index beb8f562e..6ee162713 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := e.protocol.stack.Stats().ICMP - sent := stats.V6PacketsSent - received := stats.V6PacketsReceived + sent := stats.V6.PacketsSent + received := stats.V6.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } // TODO(b/112892170): Meaningfully handle all ICMP types. - switch h.Type() { + switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) @@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(packet.NDPPayload()) + na := header.NDPNeighborAdvert(packet.MessageBody()) // As per RFC 4861 section 7.2.4: // @@ -644,8 +644,39 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } + case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + received.MulticastListenerQuery.Increment() + case header.ICMPv6MulticastListenerReport: + received.MulticastListenerReport.Increment() + case header.ICMPv6MulticastListenerDone: + received.MulticastListenerDone.Increment() + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { + received.Invalid.Increment() + return + } + + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + e.mu.Lock() + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerReport: + e.mu.Lock() + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerDone: + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + default: - received.Invalid.Increment() + received.Unrecognized.Increment() } } @@ -681,12 +712,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(packet.NDPPayload()) + ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - stat := p.stack.Stats().ICMP.V6PacketsSent + stat := p.stack.Stats().ICMP.V6.PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, @@ -796,7 +827,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { + isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) + if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { return nil } @@ -812,8 +844,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // If we are operating as a router, do not use the packet's destination // address as the response's source address as we should not own the // destination address of a packet we are forwarding. + // + // If the packet was originally destined to a multicast address, then do not + // use the packet's destination address as the source for the response ICMP + // packet as "multicast addresses must not be used as source addresses in IPv6 + // packets", as per RFC 4291 section 2.7. localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonHopLimitExceeded); ok { + if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { localAddr = "" } // Even if we were able to receive a packet from some remote, we may not have @@ -827,7 +864,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi defer route.Release() stats := p.stack.Stats().ICMP - sent := stats.V6PacketsSent + sent := stats.V6.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 9bc02d851..02b18e9a5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -150,9 +150,9 @@ func (*testInterface) Promiscuous() bool { func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -271,6 +271,22 @@ func TestICMPCounts(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -280,11 +296,11 @@ func TestICMPCounts(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) ep.HandlePacket(pkt) } @@ -301,7 +317,7 @@ func TestICMPCounts(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -413,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -422,11 +454,11 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) ep.HandlePacket(pkt) } @@ -443,7 +475,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -568,8 +600,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr { + t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr) } // Pull the full payload since network header. Needed for header.IPv6 to @@ -821,11 +853,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), @@ -833,7 +865,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -898,11 +930,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1016,11 +1048,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(icmpSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1028,7 +1060,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1076,11 +1108,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1195,11 +1227,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(size + payloadSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), @@ -1207,7 +1239,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1349,8 +1381,8 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) } if pkt.Route.RemoteAddress != test.expectedRemoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) @@ -1413,11 +1445,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1431,8 +1463,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1455,11 +1487,11 @@ func TestPacketQueing(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1473,8 +1505,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1524,8 +1556,8 @@ func TestPacketQueing(t *testing.T) { t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) } snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1543,7 +1575,7 @@ func TestPacketQueing(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) @@ -1554,11 +1586,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1592,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1612,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1629,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1645,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1662,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1683,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1702,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1722,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1796,11 +1828,11 @@ func TestCallsToNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.source, - DstAddr: test.destination, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.source, + DstAddr: test.destination, }) ep.HandlePacket(pkt) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7a00f6314..a49b5ac77 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "fmt" "hash/fnv" + "math" "sort" "sync/atomic" "time" @@ -34,7 +35,9 @@ import ( ) const ( + // ReassembleTimeout controls how long a fragment will be held. // As per RFC 8200 section 4.5: + // // If insufficient fragments are received to complete reassembly of a packet // within 60 seconds of the reception of the first-arriving fragment of that // packet, reassembly of that packet must be abandoned. @@ -83,6 +86,7 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState + mld mldState } } @@ -118,6 +122,45 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// onAddressAssignedLocked handles an address being assigned. +// +// Precondition: e.mu must be exclusively locked. +func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, ... + // + // If we just completed DAD for a link-local address, then attempt to send any + // queued MLD reports. Note, we may have sent reports already for some of the + // groups before we had a valid link-local address to use as the source for + // the MLD messages, but that was only so that MLD snooping switches are aware + // of our membership to groups - routers would not have handled those reports. + // + // As per RFC 3590 section 4, + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + if header.IsV6LinkLocalAddress(addr) { + e.mu.mld.sendQueuedReports() + } +} + // InvalidateDefaultRouter implements stack.NDPEndpoint. func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() @@ -224,6 +267,12 @@ func (e *endpoint) Enable() *tcpip.Error { return nil } + // Groups may have been joined when the endpoint was disabled, or the + // endpoint may have left groups from the perspective of MLD when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.mu.mld.initializeAll() + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -241,8 +290,10 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { - return err + if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err)) } // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent @@ -251,7 +302,7 @@ func (e *endpoint) Enable() *tcpip.Error { // Addresses may have aleady completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. var err *tcpip.Error - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { addr := addressEndpoint.AddressWithPrefix().Address if !header.IsV6UnicastAddress(addr) { return true @@ -273,7 +324,7 @@ func (e *endpoint) Enable() *tcpip.Error { } // Do not auto-generate an IPv6 link-local address for loopback devices. - if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() { // The valid and preferred lifetime is infinite for the auto-generated // link-local address. e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) @@ -322,7 +373,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.Enabled() { return } @@ -331,9 +382,17 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } + + // Leave groups from the perspective of MLD so that routers know that + // we are no longer interested in the group. + e.mu.mld.softLeaveAll() + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -341,7 +400,7 @@ func (e *endpoint) disableLocked() { // Precondition: e.mu must be write locked. func (e *endpoint) stopDADForPermanentAddressesLocked() { // Stop DAD for all the tentative unicast addresses. - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.GetKind() != stack.PermanentTentative { return true } @@ -373,19 +432,27 @@ func (e *endpoint) MTU() uint32 { // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { + // TODO(gvisor.dev/issues/5035): The maximum header length returned here does + // not open the possibility for the caller to know about size required for + // extension headers. return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) { + extHdrsLen := extensionHeaders.Length() + length := pkt.Size() + extensionHeaders.Length() + if length > math.MaxUint16 { + panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16)) + } + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(params.Protocol), - HopLimit: params.TTL, - TrafficClass: params.TOS, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(length), + TransportProtocol: params.Protocol, + HopLimit: params.TTL, + TrafficClass: params.TOS, + SrcAddr: srcAddr, + DstAddr: dstAddr, + ExtensionHeaders: extensionHeaders, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -440,7 +507,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */) // iptables filtering. All packets that reach here are locally // generated. @@ -529,7 +596,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -737,8 +804,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } - addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) - if addressEndpoint == nil { + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { stats.IP.InvalidDestinationAddressesReceived.Increment() return @@ -747,7 +817,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { _ = e.forwardPacket(pkt) return } - addressEndpoint.DecRef() // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). @@ -1090,9 +1159,16 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. + // The location of the Next Header field is in a different place in + // the initial IPv6 header than it is in the extension headers so + // treat it specially. + prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset) + if previousHeaderStart != 0 { + prevHdrIDOffset = previousHeaderStart + } _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), + pointer: prevHdrIDOffset, }, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) @@ -1100,12 +1176,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(&icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), - }, pkt) - stats.UnknownProtocolRcvdPackets.Increment() - return + // Since the iterator returns IPv6RawPayloadHeader for unknown Extension + // Header IDs this should never happen unless we missed a supported type + // here. + panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr)) + } } } @@ -1153,11 +1228,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre return addressEndpoint, nil } - snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { - return nil, err - } - addressEndpoint.SetKind(stack.PermanentTentative) if e.Enabled() { @@ -1166,6 +1236,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } } + snmc := header.SolicitedNodeAddr(addr.Address) + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) + } + return addressEndpoint, nil } @@ -1211,7 +1288,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + // The endpoint may have already left the multicast group. + if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1234,7 +1312,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { // // Precondition: e.mu must be read or write locked. func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { - return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) + return e.mu.addressableEndpointState.GetAddress(localAddr) } // MainAddress implements stack.AddressableEndpoint. @@ -1266,6 +1344,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) } +// getLinkLocalAddressRLocked returns a link-local address from the primary list +// of addresses, if one is available. +// +// See stack.PrimaryEndpointBehavior for more details about the primary list. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { + var linkLocalAddr tcpip.Address + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.IsAssigned(false /* allowExpired */) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + linkLocalAddr = addr + return false + } + } + return true + }) + return linkLocalAddr +} + // acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress // but with locking requirements. // @@ -1285,10 +1383,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { // If r is not valid for outgoing connections, it is not a valid endpoint. if !addressEndpoint.IsAssigned(allowExpired) { - return + return true } addr := addressEndpoint.AddressWithPrefix().Address @@ -1304,6 +1402,8 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address addressEndpoint: addressEndpoint, scope: scope, }) + + return true }) remoteScope, err := header.ScopeForIPv6Address(remoteAddr) @@ -1376,28 +1476,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.mu.mld.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mu.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mu.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -1405,7 +1520,8 @@ var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { - stack *stack.Stack + stack *stack.Stack + options Options mu struct { sync.RWMutex @@ -1429,26 +1545,6 @@ type protocol struct { forwarding uint32 fragmentation *fragmentation.Fragmentation - - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher - - // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. - ndpConfigs NDPConfigurations - - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - - // autoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -1481,16 +1577,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp = ndpState{ - ep: e, - configs: p.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - e.mu.ndp.initializeTempAddrState() + e.mu.ndp.init(e) + e.mu.mld.init(e) + e.mu.Unlock() p.mu.Lock() defer p.mu.Unlock() @@ -1613,17 +1704,17 @@ type Options struct { // NDPConfigs is the default NDP configurations used by interfaces. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback + // AutoGenLinkLocal determines whether or not the stack attempts to + // auto-generate a link-local address for newly enabled non-loopback // NICs. // // Note, setting this to true does not mean that a link-local address is // assigned right away, or at all. If Duplicate Address Detection is enabled, // an address is only assigned if it successfully resolves. If it fails, no - // further attempts are made to auto-generate an IPv6 link-local adddress. + // further attempts are made to auto-generate a link-local adddress. // // The generated link-local address follows RFC 4291 Appendix A guidelines. - AutoGenIPv6LinkLocal bool + AutoGenLinkLocal bool // NDPDisp is the NDP event dispatcher that an integrator can provide to // receive NDP related events. @@ -1647,6 +1738,9 @@ type Options struct { // seed that is too small would reduce randomness and increase predictability, // defeating the purpose of temporary SLAAC addresses. TempIIDSeed []byte + + // MLD holds options for MLD. + MLD MLDOptions } // NewProtocolWithOptions returns an IPv6 network protocol. @@ -1658,15 +1752,11 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ - stack: s, + stack: s, + options: opts, + ids: ids, hashIV: hashIV, - - ndpDisp: opts.NDPDisp, - ndpConfigs: opts.NDPConfigs, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, } p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[*endpoint]struct{}) @@ -1712,24 +1802,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea fragPkt.NetworkProtocolNumber = ProtocolNumber originalIPHeadersLength := len(originalIPHeaders) - fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + + s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + M: more, + Identification: id, + }} + + fragmentIPHeadersLength := originalIPHeadersLength + s.Length() fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) - fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) } - fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) - fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) - fragmentHeader.Encode(&header.IPv6FragmentFields{ - M: more, - FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), - Identification: id, - NextHeader: uint8(transportProto), - }) + nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:]) + + fragmentIPHeaders.SetNextHeader(nextHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) return fragPkt, more } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index a671d4bac..5f07d3af8 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -51,6 +51,7 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier) extraHeaderReserve = 50 ) @@ -68,18 +69,18 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived if got := stats.NeighborAdvert.Value(); got != want { t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) @@ -126,11 +127,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -573,6 +574,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { expectICMP: false, }, { + name: "unknown next header (first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, unknownHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6NextHeaderOffset, + }, + { + name: "unknown next header (not first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + unknownHdrID, 0, + 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { name: "destination with unknown option skippable action", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -755,11 +783,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { pointer: header.IPv6FixedHeaderSize, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, - shouldAccept: false, - }, - { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -873,7 +896,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), udpPayload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } + + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength)) sum = header.Checksum(udpPayload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) @@ -884,16 +913,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - dstAddr := tcpip.Address(addr2) - if test.multicast { - dstAddr = header.IPv6AllNodesMulticastAddress - } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), - NextHeader: ipv6NextHdr, - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, + // We're lying about transport protocol here to be able to generate + // raw extension headers from the test definitions. + TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), + HopLimit: 255, + SrcAddr: addr1, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -982,9 +1009,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { udpPayload2Length = 128 // Used to test cases where the fragment blocks are not a multiple of // the fragment block size of 8 (RFC 8200 section 4.5). - udpPayload3Length = 127 - udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize - fragmentExtHdrLen = 8 + udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize + udpMaximumSizeMinus15 = header.UDPMaximumSize - 15 + fragmentExtHdrLen = 8 // Note, not all routing extension headers will be 8 bytes but this test // uses 8 byte routing extension headers for most sub tests. routingExtHdrLen = 8 @@ -1328,14 +1356,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+65520, + fragmentExtHdrLen+udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[:65520], + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, ), }, @@ -1344,14 +1372,17 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + udpMaximumSizeMinus15 & 0xff, + 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[65520:], + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, ), }, @@ -1359,6 +1390,47 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + (udpMaximumSizeMinus15 & 0xff) + 1, + 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { @@ -1877,10 +1949,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, + // We're lying about transport protocol here so that we can generate + // raw extension headers for the tests. + TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, }) vv := hdr.View().ToVectorisedView() @@ -1925,7 +1999,7 @@ func TestInvalidIPv6Fragments(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -1944,14 +2018,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 9, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 9, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0 >> 3, M: true, Identification: ident, @@ -1971,14 +2044,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, M: false, Identification: ident, @@ -2019,10 +2091,9 @@ func TestInvalidIPv6Fragments(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) - - fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2084,7 +2155,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -2098,14 +2169,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2120,14 +2190,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2136,14 +2205,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2158,14 +2226,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2180,14 +2247,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2196,14 +2262,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2218,14 +2283,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2234,14 +2298,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2280,10 +2343,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2439,7 +2503,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -2456,7 +2520,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) @@ -2924,11 +2988,11 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: test.TTL, + SrcAddr: remoteIPv6Addr1, + DstAddr: remoteIPv6Addr2, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go new file mode 100644 index 000000000..e8d1e7a79 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld.go @@ -0,0 +1,262 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited MLD reports. + // + // Obtained from RFC 2710 Section 7.10. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// MLDOptions holds options for MLD. +type MLDOptions struct { + // Enabled indicates whether MLD will be performed. + // + // When enabled, MLD may transmit MLD report and done messages when + // joining and leaving multicast groups respectively, and handle incoming + // MLD packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*mldState)(nil) + +// mldState is the per-interface MLD state. +// +// mldState.init MUST be called to initialize the MLD state. +type mldState struct { + // The IPv6 endpoint this mldState is for. + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState +} + +// Enabled implements ip.MulticastGroupProtocol. +func (mld *mldState) Enabled() bool { + // No need to perform MLD on loopback interfaces since they don't have + // neighbouring nodes. + return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled() +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + return err +} + +// init sets up an mldState struct, and is required to be called before using +// a new mldState. +// +// Must only be called once for the lifetime of mld. +func (mld *mldState) init(ep *endpoint) { + mld.ep = ep + mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: mld, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv6AllNodesMulticastAddress, + }) +} + +// handleMulticastListenerQuery handles a query message. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) +} + +// handleMulticastListenerReport handles a report message. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress()) +} + +// joinGroup handles joining a new group and sending and scheduling the required +// messages. +// +// If the group is already joined, returns tcpip.ErrDuplicateAddress. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) joinGroup(groupAddress tcpip.Address) { + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress) +} + +// isInGroup returns true if the specified group has been joined locally. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { + return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Done message, if +// required. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of MLD, but remains +// joined locally. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) softLeaveAll() { + mld.genericMulticastProtocol.MakeAllNonMemberLocked() +} + +// initializeAll attemps to initialize the MLD state for each group that has +// been joined locally. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) initializeAll() { + mld.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) sendQueuedReports() { + mld.genericMulticastProtocol.SendQueuedReportsLocked() +} + +// writePacket assembles and sends an MLD packet. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { + sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + var mldStat *tcpip.StatCounter + switch mldType { + case header.ICMPv6MulticastListenerReport: + mldStat = sentStats.MulticastListenerReport + case header.ICMPv6MulticastListenerDone: + mldStat = sentStats.MulticastListenerDone + default: + panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) + } + + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) + icmp.SetType(mldType) + header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert + // option in a Hop-by-Hop Options header. + // + // However, this would cause problems with Duplicate Address Detection with + // the first address as MLD snooping switches may not send multicast traffic + // that DAD depends on to the node performing DAD without the MLD report, as + // documented in RFC 4816: + // + // Note that when a node joins a multicast address, it typically sends a + // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810] + // for the multicast address. In the case of Duplicate Address + // Detection, the MLD report message is required in order to inform MLD- + // snooping switches, rather than routers, to forward multicast packets. + // In the above description, the delay for joining the multicast address + // thus means delaying transmission of the corresponding MLD report + // message. Since the MLD specifications do not request a random delay + // to avoid race conditions, just delaying Neighbor Solicitation would + // cause congestion by the MLD report messages. The congestion would + // then prevent the MLD-snooping switches from working correctly and, as + // a result, prevent Duplicate Address Detection from working. The + // requirement to include the delay for the MLD report in this case + // avoids this scenario. [RFC3590] also talks about some interaction + // issues between Duplicate Address Detection and MLD, and specifies + // which source address should be used for the MLD report in this case. + // + // As per RFC 3590 section 4, we should still send out MLD reports with an + // unspecified source address if we do not have an assigned link-local + // address to use as the source address to ensure DAD works as expected on + // networks with MLD snooping switches: + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + localAddress := mld.ep.getLinkLocalAddressRLocked() + if len(localAddress) == 0 { + localAddress = header.IPv6Any + } + + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + + extensionHeaders := header.IPv6ExtHdrSerializer{ + header.IPv6SerializableHopByHopExtHdr{ + &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, + }, + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(), + Data: buffer.View(icmp).ToVectorisedView(), + }) + + mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.MLDHopLimit, + }, extensionHeaders) + if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return false, err + } + mldStat.Increment() + return localAddress != header.IPv6Any, nil +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go new file mode 100644 index 000000000..e2778b656 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -0,0 +1,297 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" +) + +var ( + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) + globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) +) + +func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { + t.Helper() + + checker.IPv6WithExtHdr(t, p, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(localAddress), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(mldType, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + +func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // The stack will join an address's solicited node multicast address when + // an address is added. An MLD report message should be sent for the + // solicited-node group. + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + + // The stack will leave an address's solicited node multicast address when + // an address is removed. An MLD done message should be sent for the + // solicited-node group. + if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a done message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) + } +} + +func TestSendQueuedMLDReports(t *testing.T) { + const ( + nicID = 1 + maxReports = 2 + ) + + tests := []struct { + name string + dadTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD Disabled", + dadTransmits: 0, + retransmitTimer: 0, + }, + { + name: "DAD Enabled", + dadTransmits: 1, + retransmitTimer: time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dadTransmits, + RetransmitTimer: test.retransmitTimer, + }, + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + Clock: clock, + }) + + // Allow space for an extra packet so we can observe packets that were + // unexpectedly sent. + e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + resolveDAD := func(addr, snmc tcpip.Address) { + clock.Advance(dadResolutionTime) + if p, ok := e.Read(); !ok { + t.Fatal("expected DAD packet") + } else { + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr), + checker.NDPNSOptions(nil), + )) + } + } + + var reportCounter uint64 + reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + var doneCounter uint64 + doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + + // Joining a group without an assigned address should send an MLD report + // with the unspecified address. + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalMulticastAddr) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a global address should not send reports for the already joined + // group since we should only send queued reports when a link-local + // addres sis assigned. + // + // Note, we will still expect to send a report for the global address's + // solicited node address from the unspecified address as per RFC 3590 + // section 4. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) + } + if dadResolutionTime != 0 { + // Reports should not be sent when the address resolves. + resolveDAD(globalAddr, globalAddrSNMC) + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + } + // Leave the group since we don't care about the global address's + // solicited node multicast group membership. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) + } + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a link-local address should send a report for its solicited node + // address and globalMulticastAddr. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + } + if dadResolutionTime != 0 { + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + resolveDAD(linkLocalAddr, linkLocalAddrSNMC) + } + + // We expect two batches of reports to be sent (1 batch when the + // link-local address is assigned, and another after the maximum + // unsolicited report interval. + for i := 0; i < 2; i++ { + // We expect reports to be sent (one for globalMulticastAddr and another + // for linkLocalAddrSNMC). + reportCounter += maxReports + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + + addrs := map[tcpip.Address]bool{ + globalMulticastAddr: false, + linkLocalAddrSNMC: false, + } + for _ = range addrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) + } + + addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() + if seen, ok := addrs[addr]; !ok { + t.Fatalf("got unexpected packet destined to %s", addr) + } else if seen { + t.Fatalf("got another packet destined to %s", addr) + } + + addrs[addr] = true + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) + + clock.Advance(ipv6.UnsolicitedReportIntervalMax) + } + } + + // Should not send any more reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 40da011f8..d515eb622 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -20,6 +20,7 @@ import ( "math/rand" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + // The IPv6 endpoint this ndpState is for. ep *endpoint @@ -471,17 +475,8 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - rtrSolicit struct { - // The timer used to send the next router solicitation message. - timer tcpip.Timer - - // Used to let the Router Solicitation timer know that it has been stopped. - // - // Must only be read from or written to while protected by the lock of - // the IPv6 endpoint this ndpState is associated with. MUST be set when the - // timer is set. - done *bool - } + // The job used to send the next router solicitation message. + rtrSolicitJob *tcpip.Job // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -507,7 +502,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer tcpip.Timer + job *tcpip.Job // Used to let the DAD timer know that it has been stopped. // @@ -648,96 +643,73 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // Consider DAD to have resolved even if no DAD messages were actually // transmitted. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } + ndp.ep.onAddressAssignedLocked(addr) return nil } - var done bool - var timer tcpip.Timer - // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the IPv6 endpoint's lock. This is effectively - // the same as starting a goroutine but we use a timer that fires immediately - // so we can reset it for the next DAD iteration. - timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { - ndp.ep.mu.Lock() - defer ndp.ep.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the IPv6 endpoint lock and stopped - // DAD before this function obtained the NIC lock. Simply return here and - // do nothing further. - return - } + state := dadState{ + job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + state, ok := ndp.dad[addr] + if !ok { + panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID())) + } - if addressEndpoint.GetKind() != stack.PermanentTentative { - // The endpoint should still be marked as tentative since we are still - // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) - } + if addressEndpoint.GetKind() != stack.PermanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) + } - dadDone := remaining == 0 - - var err *tcpip.Error - if !dadDone { - // Use the unspecified address as the source address when performing DAD. - addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - - // Do not hold the lock when sending packets which may be a long running - // task or may block link address resolution. We know this is safe - // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the IPv6 endpoint. Note, - // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if - // the address was removed. - ndp.ep.mu.Unlock() - err = ndp.sendDADPacket(addr, addressEndpoint) - ndp.ep.mu.Lock() - addressEndpoint.DecRef() - } + dadDone := remaining == 0 - if done { - // If we reach this point, it means that DAD was stopped after we released - // the IPv6 endpoint's read lock and before we obtained the write lock. - return - } + var err *tcpip.Error + if !dadDone { + err = ndp.sendDADPacket(addr, addressEndpoint) + } - if dadDone { - // DAD has resolved. - addressEndpoint.SetKind(stack.Permanent) - } else if err == nil { - // DAD is not done and we had no errors when sending the last NDP NS, - // schedule the next DAD timer. - remaining-- - timer.Reset(ndp.configs.RetransmitTimer) - return - } + if dadDone { + // DAD has resolved. + addressEndpoint.SetKind(stack.Permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. + remaining-- + state.job.Schedule(ndp.configs.RetransmitTimer) + return + } - // At this point we know that either DAD is done or we hit an error sending - // the last NDP NS. Either way, clean up addr's DAD state and let the - // integrator know DAD has completed. - delete(ndp.dad, addr) + // At this point we know that either DAD is done or we hit an error + // sending the last NDP NS. Either way, clean up addr's DAD state and let + // the integrator know DAD has completed. + delete(ndp.dad, addr) - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) - } + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) + } - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) - } - }) + if dadDone { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } - ndp.dad[addr] = dadState{ - timer: timer, - done: &done, + ndp.ep.onAddressAssignedLocked(addr) + } + }), } + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + state.job.Schedule(0) + ndp.dad[addr] = state + return nil } @@ -745,55 +717,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // addr. // // addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -// -// The IPv6 endpoint that ndp belongs to MUST NOT be locked. func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - - // Route should resolve immediately since snmc is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is not - // locked, the NIC could have been disabled or removed by another goroutine. - if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { - return err - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) - } - - icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) - icmpData.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(addr) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(icmpData).ToVectorisedView(), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), + Data: buffer.View(icmp).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() return err } sent.NeighborSolicit.Increment() - return nil } @@ -812,18 +760,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { return } - if dad.timer != nil { - dad.timer.Stop() - dad.timer = nil - - *dad.done = true - dad.done = nil - } - + dad.job.Cancel() delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } @@ -846,7 +787,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we // only inform the dispatcher on configuration changes. We do nothing else // with the information. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { var configuration DHCPv6ConfigurationFromNDPRA switch { case ra.ManagedAddrConfFlag(): @@ -903,20 +844,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt := opt.(type) { case header.NDPRecursiveDNSServer: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } addrs, _ := opt.Addresses() - ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } domainNames, _ := opt.DomainNames() - ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -964,7 +905,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } @@ -976,7 +917,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1006,7 +947,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1047,7 +988,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1225,7 +1166,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return nil } @@ -1272,7 +1213,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt } dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { + if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, @@ -1676,7 +1617,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint } addressEndpoint.SetDeprecated(true) - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } @@ -1701,7 +1642,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1761,7 +1702,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1859,7 +1800,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicit.timer != nil { + if ndp.rtrSolicitJob != nil { // We are already soliciting routers. return } @@ -1876,56 +1817,14 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - var done bool - ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { - ndp.ep.mu.Lock() - if done { - // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the IPv6 endpoint lock and stopped - // solicitations. Simply return here and do nothing further. - ndp.ep.mu.Unlock() - return - } - + ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) - if addressEndpoint == nil { - // Incase this ends up creating a new temporary address, we need to hold - // onto the endpoint until a route is obtained. If we decrement the - // reference count before obtaing a route, the address's resources would - // be released and attempting to obtain a route after would fail. Once a - // route is obtainted, it is safe to decrement the reference count since - // obtaining a route increments the address's reference count. - addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - } - ndp.ep.mu.Unlock() - - localAddr := addressEndpoint.AddressWithPrefix().Address - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) - addressEndpoint.DecRef() - if err != nil { - return - } - defer r.Release() - - // Route should resolve immediately since - // header.IPv6AllRoutersMulticastAddress is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is - // not locked, the IPv6 endpoint could have been disabled or removed by - // another goroutine. - if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { - return - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) + localAddr := header.IPv6Any + if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + localAddr = addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1936,30 +1835,31 @@ func (ndp *ndpState) startSolicitingRouters() { // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. var optsSerializer header.NDPOptionsSerializer - if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + linkAddress := ndp.ep.nic.LinkAddress() + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) { optsSerializer = header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + header.NDPSourceLinkLayerAddressOption(linkAddress), } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(icmpData.NDPPayload()) + rs := header.NDPRouterSolicit(icmpData.MessageBody()) rs.Options().Serialize(optsSerializer) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. @@ -1969,21 +1869,12 @@ func (ndp *ndpState) startSolicitingRouters() { remaining-- } - ndp.ep.mu.Lock() - if done || remaining == 0 { - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil - } else if ndp.rtrSolicit.timer != nil { - // Note, we need to explicitly check to make sure that - // the timer field is not nil because if it was nil but - // we still reached this point, then we know the IPv6 endpoint - // was requested to stop soliciting routers so we don't - // need to send the next Router Solicitation message. - ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) + if remaining != 0 { + ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval) } - ndp.ep.mu.Unlock() }) + ndp.rtrSolicitJob.Schedule(delay) } // stopSolicitingRouters stops soliciting routers. If routers are not currently @@ -1991,22 +1882,28 @@ func (ndp *ndpState) startSolicitingRouters() { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicit.timer == nil { + if ndp.rtrSolicitJob == nil { // Nothing to do. return } - *ndp.rtrSolicit.done = true - ndp.rtrSolicit.timer.Stop() - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil + ndp.rtrSolicitJob.Cancel() + ndp.rtrSolicitJob = nil } -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. -func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) +func (ndp *ndpState) init(ep *endpoint) { + if ndp.dad != nil { + panic("attempted to initialize NDP state twice") + } + + ndp.ep = ep + ndp.configs = ep.protocol.options.NDPConfigs + ndp.dad = make(map[tcpip.Address]dadState) + ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) + ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 37e8b1083..05a0d95b2 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -213,14 +213,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -319,23 +319,23 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(nicAddr) opts := ns.Options() opts.Serialize(test.nsOpts) @@ -599,14 +599,14 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != respNSDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(test.nsSrc) @@ -681,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, }) e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != test.naDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -785,14 +785,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -898,23 +898,23 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -979,29 +979,25 @@ func TestNDPValidation(t *testing.T) { } handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - var extensions buffer.View + var extHdrs header.IPv6ExtHdrSerializer if atomicFragment { - extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) - extensions[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) } + extHdrsLen := extHdrs.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, Data: payload.ToVectorisedView(), }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + len(extensions)), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, }) - if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { - t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) - } ep.HandlePacket(pkt) } @@ -1122,7 +1118,7 @@ func TestNDPValidation(t *testing.T) { s.SetForwarding(ProtocolNumber, true) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -1346,19 +1342,19 @@ func TestRouterAdvertValidation(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) + copy(pkt.MessageBody(), test.ndpPayload) payloadLength := hdr.UsedLength() pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, }) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid rxRA := stats.RouterAdvert diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go new file mode 100644 index 000000000..05d98a0a5 --- /dev/null +++ b/pkg/tcpip/network/multicast_group_test.go @@ -0,0 +1,1261 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "fmt" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + + ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") + ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") + ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") + ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") + ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") + ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") + + igmpMembershipQuery = uint8(header.IGMPMembershipQuery) + igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) + igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) + igmpLeaveGroup = uint8(header.IGMPLeaveGroup) + mldQuery = uint8(header.ICMPv6MulticastListenerQuery) + mldReport = uint8(header.ICMPv6MulticastListenerReport) + mldDone = uint8(header.ICMPv6MulticastListenerDone) + + maxUnsolicitedReports = 2 +) + +var ( + // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the + // NIC will wait before sending an unsolicited report after joining a + // multicast group, in deciseconds. + unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { + const decisecond = time.Second / 10 + if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { + panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) + } + return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) + }() + + ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) +) + +// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet +// sent to the provided address with the passed fields set. +func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv6WithExtHdr(t, payload, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(ipv6Addr), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, + checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + +// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. +func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.SrcAddr(ipv4Addr), + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(header.IGMPType(igmpType)), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) + s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) + return e, s, clock +} + +func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { + t.Helper() + + igmpEnabled := v4 && mgpEnabled + mldEnabled := !v4 && mgpEnabled + + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: mldEnabled, + }, + }), + }, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) + } + + return s, clock +} + +// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join +// when it is created with an IPv6 address. +// +// To not interfere with tests, checkInitialIPv6Groups will leave the added +// address's solicited node multicast group so that the tests can all assume +// the NIC has not joined any IPv6 groups. +func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { + t.Helper() + + stats := s.Stats().ICMP.V6.PacketsSent + + reportCounter++ + if got := stats.MulticastListenerReport.Value(); got != reportCounter { + t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) + } + + // Leave the group to not affect the tests. This is fine since we are not + // testing DAD or the solicited node address specifically. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) + } + leaveCounter++ + if got := stats.MulticastListenerDone.Value(); got != leaveCounter { + t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + return reportCounter, leaveCounter +} + +// createAndInjectIGMPPacket creates and injects an IGMP packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: header.IGMPTTL, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(header.IGMPType(igmpType)) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// createAndInjectMLDPacket creates and injects an MLD packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { + icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize + buf := buffer.NewView(header.IPv6MinimumSize + icmpSize) + + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + HopLimit: header.MLDHopLimit, + TransportProtocol: header.ICMPv6ProtocolNumber, + SrcAddr: header.IPv4Any, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + icmp := header.ICMPv6(buf[header.IPv6MinimumSize:]) + icmp.SetType(header.ICMPv6Type(mldType)) + mld := header.MLD(icmp.MessageBody()) + mld.SetMaximumResponseDelay(uint16(maxRespDelay)) + mld.SetMulticastAddress(groupAddress) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + + e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestMGPDisabled tests that the multicast group protocol is not enabled by +// default. +func TestMGPDisabled(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) + + // This NIC may join multicast groups when it is enabled but since MGP is + // disabled, no reports should be sent. + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) + } + + // Inject a general query message. This should only trigger a report to be + // sent if the MGP was enabled. + test.rxQuery(e) + if got := test.receivedQueryStat(s).Value(); got != 1 { + t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + }) + } +} + +func TestMGPReceiveCounters(t *testing.T) { + tests := []struct { + name string + headerType uint8 + maxRespTime byte + groupAddress tcpip.Address + statCounter func(*stack.Stack) *tcpip.StatCounter + rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) + }{ + { + name: "IGMP Membership Query", + headerType: igmpMembershipQuery, + maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, + groupAddress: header.IPv4Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv1 Membership Report", + headerType: igmpv1MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V1MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv2 Membership Report", + headerType: igmpv2MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V2MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMP Leave Group", + headerType: igmpLeaveGroup, + maxRespTime: 0, + groupAddress: header.IPv4AllRoutersGroup, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.LeaveGroup + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "MLD Query", + headerType: mldQuery, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Report", + headerType: mldReport, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Done", + headerType: mldDone, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone + }, + rxMGPkt: createAndInjectMLDPacket, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) + + test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) + if got := test.statCounter(s).Value(); got != 1 { + t.Fatalf("got %s received = %d, want = 1", test.name, got) + } + }) + } +} + +// TestMGPJoinGroup tests that when explicitly joining a multicast group, the +// stack schedules and sends correct Membership Reports. +func TestMGPJoinGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } + + // Test joining a specific address explicitly and verify a Report is sent + // immediately. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + reportCounter++ + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Verify the second report is sent by the maximum unsolicited response + // interval. + p, ok := e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPLeaveGroup tests that when leaving a previously joined multicast +// group the stack sends a leave/done message. +func TestMGPLeaveGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + reportCounter++ + if got := test.sentReportStat(s).Value(); got != reportCounter { + t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving the group should trigger an leave/done message to be sent. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + leaveCounter++ + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a leave message to be sent") + } else { + test.validateLeave(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPQueryMessages tests that a report is sent in response to query +// messages. +func TestMGPQueryMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint, uint8, tcpip.Address) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + multicastAddr tcpip.Address + expectReport bool + }{ + { + name: "Unspecified", + multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), + expectReport: true, + }, + { + name: "Specified", + multicastAddr: test.multicastAddr, + expectReport: true, + }, + { + name: "Specified other address", + multicastAddr: func() tcpip.Address { + addrBytes := []byte(test.multicastAddr) + addrBytes[len(addrBytes)-1]++ + return tcpip.Address(addrBytes) + }(), + expectReport: false, + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + for i := 0; i < maxUnsolicitedReports; i++ { + sentReportStat := test.sentReportStat(s) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected %d-th report message to be sent", i) + } else { + test.validateReport(t, p) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + } + if t.Failed() { + t.FailNow() + } + + // Should not send any more packets until a query. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + // Receive a query message which should trigger a report to be sent at + // some time before the maximum response time if the report is + // targeted at the host. + const maxRespTime = 100 + test.rxQuery(e, maxRespTime, subTest.multicastAddr) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p.Pkt) + } + + if subTest.expectReport { + clock.Advance(test.maxRespTimeToDuration(maxRespTime)) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } + }) + } +} + +// TestMGPQueryMessages tests that no further reports or leave/done messages +// are sent after receiving a report. +func TestMGPReportMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + rxReport func(*channel.Endpoint) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Receiving a report for a group we joined should cancel any further + // reports. + test.rxReport(e) + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("sent unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving a group after getting a report should not send a leave/done + // message. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + clock.Advance(time.Hour) + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +func TestMGPWithNICLifecycle(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddrs []tcpip.Address + finalMulticastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) + validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) + getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, + finalMulticastAddr: ipv4MulticastAddr3, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { + t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) + } + addr := header.IGMP(ipv4.Payload()).GroupAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, + finalMulticastAddr: ipv6MulticastAddr3, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, addr, mldReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + + ipv6HeaderIter := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var transport header.IPv6RawPayloadHeader + for { + h, done, err := ipv6HeaderIter.Next() + if err != nil { + t.Fatalf("ipv6HeaderIter.Next(): %s", err) + } + if done { + t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) + } + if t, ok := h.(header.IPv6RawPayloadHeader); ok { + transport = t + break + } + } + + if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { + t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) + } + icmpv6 := header.ICMPv6(transport.Buf.ToView()) + if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { + t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) + } + addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + sentReportStat := test.sentReportStat(s) + for _, a := range test.multicastAddrs { + if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected a report message to be sent for %s", a) + } else { + test.validateReport(t, p, a) + } + } + if t.Failed() { + t.FailNow() + } + + // Leave messages should be sent for the joined groups when the NIC is + // disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + sentLeaveStat := test.sentLeaveStat(s) + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + + test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Reports should be sent for the joined groups when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter += uint64(len(test.multicastAddrs)) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) report message to be sent", i) + } + + test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Joining/leaving a group while disabled should not send any messages. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + for i, _ := range test.multicastAddrs { + if _, ok := e.Read(); !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + } + for _, a := range test.multicastAddrs { + if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) + } + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) + } + } + if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) + } + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) + } + + // A report should only be sent for the group we last joined after + // enabling the NIC since the original groups were all left. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPDisabledOnLoopback tests that the multicast group protocol is not +// performed on loopback interfaces since they have no neighbours. +func TestMGPDisabledOnLoopback(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) + + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + }) + } +} diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 7cc52985e..5c3363759 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st return n, nil } -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - - return nil -} - // Attach implements LinkEndpoint.Attach. func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 2a6c7c7c0..b60a5fd76 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -15,31 +15,350 @@ package tcpip import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" ) -// SocketOptions contains all the variables which store values for socket -// level options. +// SocketOptionsHandler holds methods that help define endpoint specific +// behavior for socket level socket options. These must be implemented by +// endpoints to get notified when socket level options are set. +type SocketOptionsHandler interface { + // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint. + OnReuseAddressSet(v bool) + + // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint. + OnReusePortSet(v bool) + + // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint. + OnKeepAliveSet(v bool) + + // OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint. + // Note that v will be the inverse of TCP_NODELAY option. + OnDelayOptionSet(v bool) + + // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint. + OnCorkOptionSet(v bool) + + // LastError is invoked when SO_ERROR is read for an endpoint. + LastError() *Error +} + +// DefaultSocketOptionsHandler is an embeddable type that implements no-op +// implementations for SocketOptionsHandler methods. +type DefaultSocketOptionsHandler struct{} + +var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil) + +// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet. +func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {} + +// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet. +func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {} + +// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet. +func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {} + +// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet. +func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} + +// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet. +func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} + +// LastError implements SocketOptionsHandler.LastError. +func (*DefaultSocketOptionsHandler) LastError() *Error { + return nil +} + +// SocketOptions contains all the variables which store values for SOL_SOCKET, +// SOL_IP, SOL_IPV6 and SOL_TCP level options. // // +stateify savable type SocketOptions struct { - // mu protects fields below. - mu sync.Mutex `state:"nosave"` - broadcastEnabled bool + handler SocketOptionsHandler + + // These fields are accessed and modified using atomic operations. + + // broadcastEnabled determines whether datagram sockets are allowed to + // send packets to a broadcast address. + broadcastEnabled uint32 + + // passCredEnabled determines whether SCM_CREDENTIALS socket control + // messages are enabled. + passCredEnabled uint32 + + // noChecksumEnabled determines whether UDP checksum is disabled while + // transmitting for this socket. + noChecksumEnabled uint32 + + // reuseAddressEnabled determines whether Bind() should allow reuse of + // local address. + reuseAddressEnabled uint32 + + // reusePortEnabled determines whether to permit multiple sockets to be + // bound to an identical socket address. + reusePortEnabled uint32 + + // keepAliveEnabled determines whether TCP keepalive is enabled for this + // socket. + keepAliveEnabled uint32 + + // multicastLoopEnabled determines whether multicast packets sent over a + // non-loopback interface will be looped back. Analogous to inet->mc_loop. + multicastLoopEnabled uint32 + + // receiveTOSEnabled is used to specify if the TOS ancillary message is + // passed with incoming packets. + receiveTOSEnabled uint32 + + // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary + // message is passed with incoming packets. + receiveTClassEnabled uint32 + + // receivePacketInfoEnabled is used to specify if more inforamtion is + // provided with incoming packets such as interface index and address. + receivePacketInfoEnabled uint32 + + // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets + // being written have an IP header and the endpoint should not attach an IP + // header. + hdrIncludedEnabled uint32 + + // v6OnlyEnabled is used to determine whether an IPv6 socket is to be + // restricted to sending and receiving IPv6 packets only. + v6OnlyEnabled uint32 + + // quickAckEnabled is used to represent the value of TCP_QUICKACK option. + // It currently does not have any effect on the TCP endpoint. + quickAckEnabled uint32 + + // delayOptionEnabled is used to specify if data should be sent out immediately + // by the transport protocol. For TCP, it determines if the Nagle algorithm + // is on or off. + delayOptionEnabled uint32 + + // corkOptionEnabled is used to specify if data should be held until segments + // are full by the TCP transport protocol. + corkOptionEnabled uint32 + + // receiveOriginalDstAddress is used to specify if the original destination of + // the incoming packet should be returned as an ancillary message. + receiveOriginalDstAddress uint32 + + // mu protects the access to the below fields. + mu sync.Mutex `state:"nosave"` + + // linger determines the amount of time the socket should linger before + // close. We currently implement this option for TCP socket only. + linger LingerOption +} + +// InitHandler initializes the handler. This must be called before using the +// socket options utility. +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { + so.handler = handler +} + +func storeAtomicBool(addr *uint32, v bool) { + var val uint32 + if v { + val = 1 + } + atomic.StoreUint32(addr, val) } // GetBroadcast gets value for SO_BROADCAST option. func (so *SocketOptions) GetBroadcast() bool { - so.mu.Lock() - defer so.mu.Unlock() - - return so.broadcastEnabled + return atomic.LoadUint32(&so.broadcastEnabled) != 0 } // SetBroadcast sets value for SO_BROADCAST option. func (so *SocketOptions) SetBroadcast(v bool) { + storeAtomicBool(&so.broadcastEnabled, v) +} + +// GetPassCred gets value for SO_PASSCRED option. +func (so *SocketOptions) GetPassCred() bool { + return atomic.LoadUint32(&so.passCredEnabled) != 0 +} + +// SetPassCred sets value for SO_PASSCRED option. +func (so *SocketOptions) SetPassCred(v bool) { + storeAtomicBool(&so.passCredEnabled, v) +} + +// GetNoChecksum gets value for SO_NO_CHECK option. +func (so *SocketOptions) GetNoChecksum() bool { + return atomic.LoadUint32(&so.noChecksumEnabled) != 0 +} + +// SetNoChecksum sets value for SO_NO_CHECK option. +func (so *SocketOptions) SetNoChecksum(v bool) { + storeAtomicBool(&so.noChecksumEnabled, v) +} + +// GetReuseAddress gets value for SO_REUSEADDR option. +func (so *SocketOptions) GetReuseAddress() bool { + return atomic.LoadUint32(&so.reuseAddressEnabled) != 0 +} + +// SetReuseAddress sets value for SO_REUSEADDR option. +func (so *SocketOptions) SetReuseAddress(v bool) { + storeAtomicBool(&so.reuseAddressEnabled, v) + so.handler.OnReuseAddressSet(v) +} + +// GetReusePort gets value for SO_REUSEPORT option. +func (so *SocketOptions) GetReusePort() bool { + return atomic.LoadUint32(&so.reusePortEnabled) != 0 +} + +// SetReusePort sets value for SO_REUSEPORT option. +func (so *SocketOptions) SetReusePort(v bool) { + storeAtomicBool(&so.reusePortEnabled, v) + so.handler.OnReusePortSet(v) +} + +// GetKeepAlive gets value for SO_KEEPALIVE option. +func (so *SocketOptions) GetKeepAlive() bool { + return atomic.LoadUint32(&so.keepAliveEnabled) != 0 +} + +// SetKeepAlive sets value for SO_KEEPALIVE option. +func (so *SocketOptions) SetKeepAlive(v bool) { + storeAtomicBool(&so.keepAliveEnabled, v) + so.handler.OnKeepAliveSet(v) +} + +// GetMulticastLoop gets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) GetMulticastLoop() bool { + return atomic.LoadUint32(&so.multicastLoopEnabled) != 0 +} + +// SetMulticastLoop sets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) SetMulticastLoop(v bool) { + storeAtomicBool(&so.multicastLoopEnabled, v) +} + +// GetReceiveTOS gets value for IP_RECVTOS option. +func (so *SocketOptions) GetReceiveTOS() bool { + return atomic.LoadUint32(&so.receiveTOSEnabled) != 0 +} + +// SetReceiveTOS sets value for IP_RECVTOS option. +func (so *SocketOptions) SetReceiveTOS(v bool) { + storeAtomicBool(&so.receiveTOSEnabled, v) +} + +// GetReceiveTClass gets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) GetReceiveTClass() bool { + return atomic.LoadUint32(&so.receiveTClassEnabled) != 0 +} + +// SetReceiveTClass sets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) SetReceiveTClass(v bool) { + storeAtomicBool(&so.receiveTClassEnabled, v) +} + +// GetReceivePacketInfo gets value for IP_PKTINFO option. +func (so *SocketOptions) GetReceivePacketInfo() bool { + return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0 +} + +// SetReceivePacketInfo sets value for IP_PKTINFO option. +func (so *SocketOptions) SetReceivePacketInfo(v bool) { + storeAtomicBool(&so.receivePacketInfoEnabled, v) +} + +// GetHeaderIncluded gets value for IP_HDRINCL option. +func (so *SocketOptions) GetHeaderIncluded() bool { + return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0 +} + +// SetHeaderIncluded sets value for IP_HDRINCL option. +func (so *SocketOptions) SetHeaderIncluded(v bool) { + storeAtomicBool(&so.hdrIncludedEnabled, v) +} + +// GetV6Only gets value for IPV6_V6ONLY option. +func (so *SocketOptions) GetV6Only() bool { + return atomic.LoadUint32(&so.v6OnlyEnabled) != 0 +} + +// SetV6Only sets value for IPV6_V6ONLY option. +// +// Preconditions: the backing TCP or UDP endpoint must be in initial state. +func (so *SocketOptions) SetV6Only(v bool) { + storeAtomicBool(&so.v6OnlyEnabled, v) +} + +// GetQuickAck gets value for TCP_QUICKACK option. +func (so *SocketOptions) GetQuickAck() bool { + return atomic.LoadUint32(&so.quickAckEnabled) != 0 +} + +// SetQuickAck sets value for TCP_QUICKACK option. +func (so *SocketOptions) SetQuickAck(v bool) { + storeAtomicBool(&so.quickAckEnabled, v) +} + +// GetDelayOption gets inverted value for TCP_NODELAY option. +func (so *SocketOptions) GetDelayOption() bool { + return atomic.LoadUint32(&so.delayOptionEnabled) != 0 +} + +// SetDelayOption sets inverted value for TCP_NODELAY option. +func (so *SocketOptions) SetDelayOption(v bool) { + storeAtomicBool(&so.delayOptionEnabled, v) + so.handler.OnDelayOptionSet(v) +} + +// GetCorkOption gets value for TCP_CORK option. +func (so *SocketOptions) GetCorkOption() bool { + return atomic.LoadUint32(&so.corkOptionEnabled) != 0 +} + +// SetCorkOption sets value for TCP_CORK option. +func (so *SocketOptions) SetCorkOption(v bool) { + storeAtomicBool(&so.corkOptionEnabled, v) + so.handler.OnCorkOptionSet(v) +} + +// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) GetReceiveOriginalDstAddress() bool { + return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0 +} + +// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { + storeAtomicBool(&so.receiveOriginalDstAddress, v) +} + +// GetLastError gets value for SO_ERROR option. +func (so *SocketOptions) GetLastError() *Error { + return so.handler.LastError() +} + +// GetOutOfBandInline gets value for SO_OOBINLINE option. +func (*SocketOptions) GetOutOfBandInline() bool { + return true +} + +// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not +// support disabling this option. +func (*SocketOptions) SetOutOfBandInline(bool) {} + +// GetLinger gets value for SO_LINGER option. +func (so *SocketOptions) GetLinger() LingerOption { so.mu.Lock() - defer so.mu.Unlock() + linger := so.linger + so.mu.Unlock() + return linger +} - so.broadcastEnabled = v +// SetLinger sets value for SO_LINGER option. +func (so *SocketOptions) SetLinger(linger LingerOption) { + so.mu.Lock() + so.linger = linger + so.mu.Unlock() } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index d09ebe7fa..9cc6074da 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "most_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -112,7 +112,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], - shard_count = 20, + shard_count = most_shards, deps = [ ":stack", "//pkg/rand", @@ -120,6 +120,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", @@ -131,7 +132,6 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 9478f3fb7..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) var _ AddressableEndpoint = (*AddressableEndpointState)(nil) // AddressableEndpointState is an implementation of an AddressableEndpoint. @@ -37,10 +36,6 @@ type AddressableEndpointState struct { endpoints map[tcpip.Address]*addressState primary []*addressState - - // groups holds the mapping between group addresses and the number of times - // they have been joined. - groups map[tcpip.Address]uint32 } } @@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { a.mu.Lock() defer a.mu.Unlock() a.mu.endpoints = make(map[tcpip.Address]*addressState) - a.mu.groups = make(map[tcpip.Address]uint32) -} - -// ReadOnlyAddressableEndpointState provides read-only access to an -// AddressableEndpointState. -type ReadOnlyAddressableEndpointState struct { - inner *AddressableEndpointState } -// AddrOrMatching returns an endpoint for the passed address that is consisdered -// bound to the wrapped AddressableEndpointState. +// GetAddress returns the AddressEndpoint for the passed address. // -// If addr is an exact match with an existing address, that address is returned. -// Otherwise, f is called with each address and the address that f returns true -// for is returned. -// -// Returns nil of no address matches. -func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - - if ep, ok := m.inner.mu.endpoints[addr]; ok { - if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { - return ep - } - } - - for _, ep := range m.inner.mu.endpoints { - if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { - return ep - } - } - - return nil -} - -// Lookup returns the AddressEndpoint for the passed address. +// GetAddress does not increment the address's reference count or check if the +// address is considered bound to the endpoint. // -// Returns nil if the passed address is not associated with the -// AddressableEndpointState. -func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Returns nil if the passed address is not associated with the endpoint. +func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() - ep, ok := m.inner.mu.endpoints[addr] + ep, ok := a.mu.endpoints[addr] if !ok { return nil } return ep } -// ForEach calls f for each address pair. +// ForEachEndpoint calls f for each address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() - for _, ep := range m.inner.mu.endpoints { + for _, ep := range a.mu.endpoints { if !f(ep) { return } @@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) // ForEachPrimaryEndpoint calls f for each primary address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - for _, ep := range m.inner.mu.primary { - f(ep) - } -} +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() -// ReadOnly returns a readonly reference to a. -func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { - return ReadOnlyAddressableEndpointState{inner: a} + for _, ep := range a.mu.primary { + if !f(ep) { + return + } + } } func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { @@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { a.mu.Lock() defer a.mu.Unlock() - - if _, ok := a.mu.groups[addr]; ok { - panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) - } - return a.removePermanentAddressLocked(addr) } @@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad return deprecatedEndpoint } -// AcquireAssignedAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { +// AcquireAssignedAddressOrMatching returns an address endpoint that is +// considered assigned to the addressable endpoint. +// +// If the address is an exact match with an existing address, that address is +// returned. Otherwise, if f is provided, f is called with each address and +// the address that f returns true for is returned. +// +// If there is no matching address, a temporary address will be returned if +// allowTemp is true. +// +// Regardless how the address was obtained, it will be acquired before it is +// returned. +func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { a.mu.Lock() defer a.mu.Unlock() @@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return addrState } + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } + } + } + if !allowTemp { return nil } @@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return ep } +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB) +} + // AcquireOutgoingPrimaryAddress implements AddressableEndpoint. func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { a.mu.RLock() @@ -588,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi return addrs } -// JoinGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) - if err != nil { - return false, err - } - // We have no need for the address endpoint. - a.decAddressRefLocked(ep) - } - - a.mu.groups[group] = joins + 1 - return !ok, nil -} - -// LeaveGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - return false, tcpip.ErrBadLocalAddress - } - - if joins == 1 { - a.removeGroupAddressLocked(group) - delete(a.mu.groups, group) - return true, nil - } - - a.mu.groups[group] = joins - 1 - return false, nil -} - -// IsInGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { - a.mu.RLock() - defer a.mu.RUnlock() - _, ok := a.mu.groups[group] - return ok -} - -func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { - if err := a.removePermanentAddressLocked(group); err != nil { - // removePermanentEndpointLocked would only return an error if group is - // not bound to the addressable endpoint, but we know it MUST be assigned - // since we have group in our map of groups. - panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) - } -} - // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() defer a.mu.Unlock() - for group := range a.mu.groups { - a.removeGroupAddressLocked(group) - } - a.mu.groups = make(map[tcpip.Address]uint32) - for _, ep := range a.mu.endpoints { // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is // not a permanent address. diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 26787d0a3..140f146f6 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep.DecRef() } - group := tcpip.Address("\x02") - if added, err := s.JoinGroup(group); err != nil { - t.Fatalf("s.JoinGroup(%s): %s", group, err) - } else if !added { - t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) - } - if !s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) - } - s.Cleanup() - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } - } - if s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 6dc9e7859..5ec9b3411 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -309,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, + RemoteLinkAddress: r.RemoteLinkAddress(), LocalLinkAddress: r.LocalLinkAddress, Pkt: pkt, } @@ -333,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), - } - - select { - case e.C <- p: - default: - } - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 73a01c2dd..03d7b4e0d 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) { } // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { t.Fatalf("got NeighborSolicit = %d, want = 0", got) } } @@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) { if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) { } else if r.LocalAddress != addr1 { t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) { } // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } @@ -533,8 +540,8 @@ func TestDADResolve(t *testing.T) { // Make sure the right remote link address is used. snmc := header.SolicitedNodeAddr(addr1) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { + t.Errorf("got remote link address = %s, want = %s", got, want) } // Check NDP NS packet. @@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) } @@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) @@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, @@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate an address conflict. test.rxPkt(e, addr1) - stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) + stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) if got := stat.Value(); got != 1 { t.Fatalf("got stat = %d, want = 1", got) } @@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) { } // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { t.Errorf("got NeighborSolicit = %d, want <= 1", got) } }) @@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - raPayload := pkt.NDPPayload() + raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. binary.BigEndian.PutUint16(raPayload[2:], rl) @@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, }) return stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { NDPConfigs: ipv6.NDPConfigurations{ AutoGenTempGlobalAddresses: true, }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, })}, }) @@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(addr); err != nil { t.Fatalf("ep.Connect(%+v): %s", addr, err) } @@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Bind(addr); err != nil { t.Fatalf("ep.Bind(%+v): %s", addr, err) } @@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) @@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName @@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, + AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + // Make sure the right remote link address is used. + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want { + t.Errorf("got remote link address = %s, want = %s", got, want) + } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet") - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) - remaining-- + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) } + } - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) - waitForPkt(defaultAsyncPositiveEventTimeout) - } else { - waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt) } + } - // Make sure the counter got properly - // incremented. - if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } + + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } } func TestStopStartSolicitingRouters(t *testing.T) { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 177bf5516..317f6871d 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -24,9 +24,16 @@ import ( const neighborCacheSize = 512 // max entries per interface +// NeighborStats holds metrics for the neighbor table. +type NeighborStats struct { + // FailedEntryLookups counts the number of lookups performed on an entry in + // Failed state. + FailedEntryLookups *tcpip.StatCounter +} + // neighborCache maps IP addresses to link addresses. It uses the Least // Recently Used (LRU) eviction strategy to implement a bounded cache for -// dynmically acquired entries. It contains the state machine and configuration +// dynamically acquired entries. It contains the state machine and configuration // for running Neighbor Unreachability Detection (NUD). // // There are two types of entries in the neighbor cache: @@ -175,14 +182,15 @@ func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(n.cache)) n.mu.RLock() + defer n.mu.RUnlock() + + entries := make([]NeighborEntry, 0, len(n.cache)) for _, entry := range n.cache { entry.mu.RLock() entries = append(entries, entry.neigh) entry.mu.RUnlock() } - n.mu.RUnlock() return entries } @@ -226,6 +234,8 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd } // removeEntryLocked removes the specified entry from the neighbor cache. +// +// Prerequisite: n.mu and entry.mu MUST be locked. func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { if entry.neigh.State != Static { n.dynamic.lru.Remove(entry) diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index ed33418f3..732a299f7 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -80,17 +80,20 @@ func entryDiffOptsWithSort() []cmp.Option { func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - return &neighborCache{ + neigh := &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, nudDisp: nudDisp, }, - id: 1, + id: 1, + stats: makeNICStats(), }, state: NewNUDState(config, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } + neigh.nic.neigh = neigh + return neigh } // testEntryStore contains a set of IP to NeighborEntry mappings. diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 493e48031..32399b4f5 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -258,7 +258,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { case Failed: e.notifyWakersLocked() - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() { e.nic.neigh.removeEntryLocked(e) }) e.job.Schedule(config.UnreachableTime) @@ -347,9 +347,10 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.setStateLocked(Delay) e.dispatchChangeEventLocked() - case Incomplete, Reachable, Delay, Probe, Static, Failed: + case Incomplete, Reachable, Delay, Probe, Static: // Do nothing - + case Failed: + e.nic.stats.Neighbor.FailedEntryLookups.Increment() default: panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } @@ -511,3 +512,23 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() { panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } } + +// doubleLock combines two locks into one while maintaining lock ordering. +// +// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed +// neighbor is allowed. +type doubleLock struct { + first, second sync.Locker +} + +// Lock locks both locks in order: first then second. +func (l *doubleLock) Lock() { + l.first.Lock() + l.second.Lock() +} + +// Unlock unlocks both locks in reverse order: second then first. +func (l *doubleLock) Unlock() { + l.second.Unlock() + l.first.Unlock() +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index c2b763325..c497d3932 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -89,7 +89,7 @@ func eventDiffOptsWithSort() []cmp.Option { // | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Stale | Stale | Override confirmation | Update LinkAddr | Changed | // | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | -// | Stale | Delay | Packet sent | | Changed | +// | Stale | Delay | Packet queued | | Changed | // | Delay | Reachable | Upper-layer confirmation | | Changed | // | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | // | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | @@ -101,6 +101,7 @@ func eventDiffOptsWithSort() []cmp.Option { // | Probe | Stale | Probe or confirmation w/ different address | | Changed | // | Probe | Probe | Retransmit timer expired | Send probe | Changed | // | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Failed | Failed | Packet queued | | | // | Failed | | Unreachability timer expired | Delete entry | | type testEntryEventType uint8 @@ -228,6 +229,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock: clock, nudDisp: &disp, }, + stats: makeNICStats(), } nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), @@ -3433,6 +3435,146 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } +func TestEntryFailedToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // Verify the cache contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { + t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + } + + // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in + // their expected state. + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + clock.Advance(waitFor) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + failedLookups := e.nic.stats.Neighbor.FailedEntryLookups + if got := failedLookups.Value(); got != 0 { + t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got) + } + + e.mu.Lock() + // Verify queuing a packet to the entry immediately fails. + e.handlePacketQueuedLocked(entryTestAddr2) + state := e.neigh.State + e.mu.Unlock() + if state != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", state, Failed) + } + + if got := failedLookups.Value(); got != 1 { + t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got) + } +} + func TestEntryFailedGetsDeleted(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3e6ceff28..5d037a27e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -54,18 +54,20 @@ type NIC struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + // packetEPs is protected by mu, but the contained packetEndpointList are + // not. + packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } } -// NICStats includes transmitted and received stats. +// NICStats hold statistics for a NIC. type NICStats struct { Tx DirectionStats Rx DirectionStats DisabledRx DirectionStats + + Neighbor NeighborStats } func makeNICStats() NICStats { @@ -80,6 +82,39 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } +type packetEndpointList struct { + mu sync.RWMutex + + // eps is protected by mu, but the contained PacketEndpoint values are not. + eps []PacketEndpoint +} + +func (p *packetEndpointList) add(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.eps = append(p.eps, ep) +} + +func (p *packetEndpointList) remove(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + for i, epOther := range p.eps { + if epOther == ep { + p.eps = append(p.eps[:i], p.eps[i+1:]...) + break + } + } +} + +// forEach calls fn with each endpoints in p while holding the read lock on p. +func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, ep := range p.eps { + fn(ep) + } +} + // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -100,7 +135,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -123,11 +158,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil + nic.mu.packetEPs[netNum] = new(packetEndpointList) nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } @@ -170,7 +205,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -182,6 +217,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. @@ -265,7 +304,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } return err @@ -277,9 +316,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // WritePacketToRemote implements NetworkInterface. func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { r := Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return n.writePacket(&r, gso, protocol, pkt) } @@ -561,8 +600,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -578,11 +616,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. @@ -637,15 +671,23 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? - packetEPs := n.mu.packetEPs[protocol] - // Add any other packet type sockets that may be listening for all protocols. - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) + protoEPs := n.mu.packetEPs[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + // Deliver to interested packet endpoints without holding NIC lock. + deliverPacketEPs := func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketHost ep.HandlePacket(n.id, local, protocol, p) } + if protoEPs != nil { + protoEPs.forEach(deliverPacketEPs) + } + if anyEPs != nil { + anyEPs.forEach(deliverPacketEPs) + } // Parse headers. netProto := n.stack.NetworkProtocolInstance(protocol) @@ -686,16 +728,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + eps.forEach(func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) - } + }) } // DeliverTransportPacket delivers the packets to the appropriate transport @@ -848,7 +891,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa if !ok { return tcpip.ErrNotSupported } - n.mu.packetEPs[netProto] = append(eps, ep) + eps.add(ep) return nil } @@ -861,13 +904,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep if !ok { return } - - for i, epOther := range eps { - if epOther == ep { - n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) - return - } - } + eps.remove(ep) } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2cb13c6fa..b334e27c4 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -259,15 +259,6 @@ const ( PacketLoop ) -// NetOptions is an interface that allows us to pass network protocol specific -// options through the Stack layer code. -type NetOptions interface { - // SizeWithPadding returns the amount of memory that must be allocated to - // hold the options given that the value must be rounded up to the next - // multiple of 4 bytes. - SizeWithPadding() int -} - // NetworkHeaderParams are the header parameters given as input by the // transport endpoint to the network. type NetworkHeaderParams struct { @@ -279,10 +270,6 @@ type NetworkHeaderParams struct { // TOS refers to TypeOfService or TrafficClass field of the IP-header. TOS uint8 - - // Options is a set of options to add to a network header (or nil). - // It will be protocol specific opaque information from higher layers. - Options NetOptions } // GroupAddressableEndpoint is an endpoint that supports group addressing. @@ -291,14 +278,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - // - // Returns true if the group was newly joined. - JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + JoinGroup(group tcpip.Address) *tcpip.Error // LeaveGroup attempts to leave the specified group. - // - // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. - LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + LeaveGroup(group tcpip.Address) *tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -739,10 +722,6 @@ type LinkEndpoint interface { // endpoint. Capabilities() LinkEndpointCapabilities - // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. It takes ownership of vv. - WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error - // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 53cb6694f..de5fe6ffe 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -18,19 +18,22 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) // Route represents a route through the networking stack to a given destination. +// +// It is safe to call Route's methods from multiple goroutines. +// +// The exported fields are immutable. +// +// TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { // RemoteAddress is the final destination of the route. RemoteAddress tcpip.Address - // RemoteLinkAddress is the link-layer (MAC) address of the - // final destination of the route. - RemoteLinkAddress tcpip.LinkAddress - // LocalAddress is the local address where the route starts. LocalAddress tcpip.Address @@ -52,8 +55,16 @@ type Route struct { // address's assigned status without the NIC. localAddressNIC *NIC - // localAddressEndpoint is the local address this route is associated with. - localAddressEndpoint AssignableAddressEndpoint + mu struct { + sync.RWMutex + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint + + // remoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + remoteLinkAddress tcpip.LinkAddress + } // outgoingNIC is the interface this route uses to write packets. outgoingNIC *NIC @@ -71,22 +82,24 @@ type Route struct { // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { - addrWithPrefix := addressEndpoint.AddressWithPrefix() +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { + if len(localAddr) == 0 { + localAddr = addressEndpoint.AddressWithPrefix().Address + } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { addressEndpoint.DecRef() - return Route{} + return nil } // If no remote address is provided, use the local address. if len(remoteAddr) == 0 { - remoteAddr = addrWithPrefix.Address + remoteAddr = localAddr } r := makeRoute( netProto, - addrWithPrefix.Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -99,8 +112,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // broadcast it. if len(gateway) > 0 { r.NextHop = gateway - } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.ResolveWith(header.EthernetBroadcastAddress) } return r @@ -108,11 +121,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { if localAddressNIC.stack != outgoingNIC.stack { panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } + if len(localAddr) == 0 { + localAddr = localAddressEndpoint.AddressWithPrefix().Address + } + loop := PacketOut // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the @@ -133,18 +150,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { - r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - localAddressEndpoint: localAddressEndpoint, - outgoingNIC: outgoingNIC, - Loop: loop, +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { + r := &Route{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, + Loop: loop, } + r.mu.Lock() + r.mu.localAddressEndpoint = localAddressEndpoint + r.mu.Unlock() + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes @@ -159,7 +179,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr // provided AssignableAddressEndpoint. // // A local route is a route to a destination that is local to the stack. -func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { loop := PacketLoop // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the // link endpoint level. We can remove this check once loopback interfaces @@ -170,6 +190,14 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } +// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in +// the route. +func (r *Route) RemoteLinkAddress() tcpip.LinkAddress { + r.mu.RLock() + defer r.mu.RUnlock() + return r.mu.remoteLinkAddress +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { return r.outgoingNIC.ID() @@ -231,7 +259,9 @@ func (r *Route) GSOMaxSize() uint32 { // ResolveWith immediately resolves a route with the specified remote link // address. func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr + r.mu.Lock() + defer r.mu.Unlock() + r.mu.remoteLinkAddress = addr } // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in @@ -244,7 +274,10 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { // // The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { - if !r.IsResolutionRequired() { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. return nil, nil @@ -254,7 +287,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if nextAddr == "" { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { - r.RemoteLinkAddress = r.LocalLinkAddress + r.mu.remoteLinkAddress = r.LocalLinkAddress return nil, nil } nextAddr = r.RemoteAddress @@ -272,7 +305,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if err != nil { return ch, err } - r.RemoteLinkAddress = entry.LinkAddr + r.mu.remoteLinkAddress = entry.LinkAddr return nil, nil } @@ -280,7 +313,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if err != nil { return ch, err } - r.RemoteLinkAddress = linkAddr + r.mu.remoteLinkAddress = linkAddr return nil, nil } @@ -309,7 +342,13 @@ func (r *Route) local() bool { // // The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isResolutionRequiredRLocked() +} + +func (r *Route) isResolutionRequiredRLocked() bool { + if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { return false } @@ -317,11 +356,18 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isValidForOutgoing() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isValidForOutgoingRLocked() +} + +func (r *Route) isValidForOutgoingRLocked() bool { if !r.outgoingNIC.Enabled() { return false } - if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + localAddressEndpoint := r.mu.localAddressEndpoint + if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) { return false } @@ -375,37 +421,44 @@ func (r *Route) MTU() uint32 { // Release frees all resources associated with the route. func (r *Route) Release() { - if r.localAddressEndpoint != nil { - r.localAddressEndpoint.DecRef() - r.localAddressEndpoint = nil + r.mu.Lock() + defer r.mu.Unlock() + + if r.mu.localAddressEndpoint != nil { + r.mu.localAddressEndpoint.DecRef() + r.mu.localAddressEndpoint = nil } } // Clone clones the route. -func (r *Route) Clone() Route { - if r.localAddressEndpoint != nil { - if !r.localAddressEndpoint.IncRef() { - panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) - } +func (r *Route) Clone() *Route { + r.mu.RLock() + defer r.mu.RUnlock() + + newRoute := &Route{ + RemoteAddress: r.RemoteAddress, + LocalAddress: r.LocalAddress, + LocalLinkAddress: r.LocalLinkAddress, + NextHop: r.NextHop, + NetProto: r.NetProto, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, } - return *r -} -// MakeLoopedRoute duplicates the given route with special handling for routes -// used for sending multicast or broadcast packets. In those cases the -// multicast/broadcast address is the remote address when sending out, but for -// incoming (looped) packets it becomes the local address. Similarly, the local -// interface address that was the local address going out becomes the remote -// address coming in. This is different to unicast routes where local and -// remote addresses remain the same as they identify location (local vs remote) -// not direction (source vs destination). -func (r *Route) MakeLoopedRoute() Route { - l := r.Clone() - if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { - l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress - l.RemoteLinkAddress = l.LocalLinkAddress + newRoute.mu.Lock() + defer newRoute.mu.Unlock() + newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint + if newRoute.mu.localAddressEndpoint != nil { + if !newRoute.mu.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress)) + } } - return l + newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress + + return newRoute } // Stack returns the instance of the Stack that owns this route. @@ -418,7 +471,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.localAddressEndpoint.Subnet() + r.mu.RLock() + localAddressEndpoint := r.mu.localAddressEndpoint + r.mu.RUnlock() + if localAddressEndpoint == nil { + return false + } + + subnet := localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -428,27 +488,3 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } - -// isInboundBroadcast returns true if the route is for an inbound broadcast -// packet. -func (r *Route) isInboundBroadcast() bool { - // Only IPv4 has a notion of broadcast. - return r.isV4Broadcast(r.LocalAddress) -} - -// ReverseRoute returns new route with given source and destination address. -func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { - return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - localAddressNIC: r.localAddressNIC, - localAddressEndpoint: r.localAddressEndpoint, - outgoingNIC: r.outgoingNIC, - linkCache: r.linkCache, - linkRes: r.linkRes, - } -} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e0025e0a9..026d330c4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -171,6 +171,9 @@ type TCPSenderState struct { // Outstanding is the number of packets in flight. Outstanding int + // SackedOut is the number of packets which have been selectively acked. + SackedOut int + // SndWnd is the send window size in bytes. SndWnd seqnum.Size @@ -1118,6 +1121,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } +// AddAddressWithPrefix is the same as AddAddress, but allows you to specify +// the address prefix. +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { + ap := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: addr, + } + return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) +} + // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { @@ -1208,10 +1221,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP // from the specified NIC. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) if localAddressEndpoint == nil { - return Route{}, false + return nil } var outgoingNIC *NIC @@ -1235,12 +1248,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // route. if outgoingNIC == nil { localAddressEndpoint.DecRef() - return Route{}, false + return nil } r := makeLocalRoute( netProto, - localAddressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -1249,10 +1262,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re if r.IsOutboundBroadcast() { r.Release() - return Route{}, false + return nil } - return r, true + return r } // findLocalRouteRLocked returns a local route. @@ -1261,26 +1274,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // is, a local route is a route where packets never have to leave the stack. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { if len(localAddr) == 0 { localAddr = remoteAddr } if localAddressNICID == 0 { for _, localAddressNIC := range s.nics { - if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { - return r, true + if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil { + return r } } - return Route{}, false + return nil } if localAddressNIC, ok := s.nics[localAddressNICID]; ok { return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) } - return Route{}, false + return nil } // FindRoute creates a route to the given destination address, leaving through @@ -1294,7 +1307,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1305,7 +1318,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) if s.handleLocal && !isMulticast && !isLocalBroadcast { - if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil { return r, nil } } @@ -1317,7 +1330,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { return makeRoute( netProto, - addressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, nic, /* outboundNIC */ nic, /* localAddressNIC*/ @@ -1329,9 +1342,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1354,8 +1367,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if needRoute { gateway = route.Gateway } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) - if r == (Route{}) { + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) } return r, nil @@ -1391,13 +1404,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 { if aNIC, ok := s.nics[id]; ok { if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } } - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if id == 0 { @@ -1409,7 +1422,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } @@ -1417,12 +1430,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if header.IsV6LoopbackAddress(remoteAddr) { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1810,49 +1823,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip nic.unregisterPacketEndpoint(netProto, ep) } -// WritePacket writes data directly to the specified NIC. It adds an ethernet -// header based on the arguments. -func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { - s.mu.Lock() - nic, ok := s.nics[nicID] - s.mu.Unlock() - if !ok { - return tcpip.ErrUnknownDevice - } - - // Add our own fake ethernet header. - ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), - DstAddr: dst, - Type: netProto, - } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - vv := buffer.View(fakeHeader).ToVectorisedView() - vv.Append(payload) - - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { - return err - } - - return nil -} - -// WriteRawPacket writes data directly to the specified NIC without adding any -// headers. -func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { +// WritePacketToRemote writes a payload on the specified NIC using the provided +// network protocol and remote link address. +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice } - - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { - return err - } - - return nil + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(nic.MaxHeaderLength()), + Data: payload, + }) + return nic.WritePacketToRemote(remote, nil, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the @@ -1912,7 +1896,6 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - // TODO: notify network of subscription via igmp protocol. s.mu.RLock() defer s.mu.RUnlock() @@ -2159,3 +2142,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { } return protos } + +func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + return false + } + + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + return subnet.IsBroadcast(addr) +} + +// IsSubnetBroadcast returns true if the provided address is a subnet-local +// broadcast address on the specified NIC and protocol. +// +// Returns false if the NIC is unknown or if the protocol is unknown or does +// not support addressing. +// +// If the NIC is not specified, the stack will check all NICs. +func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + if nicID != 0 { + nic, ok := s.nics[nicID] + if !ok { + return false + } + + return isSubnetBroadcastOnNIC(nic, protocol, addr) + } + + for _, nic := range s.nics { + if isSubnetBroadcastOnNIC(nic, protocol, addr) { + return true + } + } + + return false +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 61db3164b..457990945 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,7 +27,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -407,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -425,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En } } -func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() ep.Drain() if err := send(r, payload); err != nil { @@ -436,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer. } } -func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) @@ -1563,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) { // testSendTo(t, s, remoteAddr, ep, nil) } -func verifyRoute(gotRoute, wantRoute stack.Route) error { +func verifyRoute(gotRoute, wantRoute *stack.Route) error { if gotRoute.LocalAddress != wantRoute.LocalAddress { return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) } if gotRoute.RemoteAddress != wantRoute.RemoteAddress { return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) } - if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { - return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want { + return fmt.Errorf("bad remote link address: got %s, want = %s", got, want) } if gotRoute.NextHop != wantRoute.NextHop { return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) @@ -1603,7 +1602,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1657,7 +1656,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1667,7 +1666,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1683,7 +1682,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -2407,9 +2406,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + AutoGenLinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, })}, } @@ -2502,8 +2501,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + AutoGenLinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, })}, } @@ -2536,9 +2535,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + AutoGenLinkLocal: true, + NDPDisp: &ndpDisp, })}, } @@ -3351,11 +3350,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { remNetSubnetBcast := remNetSubnet.Broadcast() tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedRoute stack.Route + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedLocalAddress tcpip.Address + expectedRemoteAddress tcpip.Address + expectedRemoteLinkAddress tcpip.LinkAddress + expectedNextHop tcpip.Address + expectedNetProto tcpip.NetworkProtocolNumber + expectedLoop stack.PacketLooping }{ // Broadcast to a locally attached subnet populates the broadcast MAC. { @@ -3370,14 +3374,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: ipv4SubnetBcast, - RemoteLinkAddress: header.EthernetBroadcastAddress, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut | stack.PacketLoop, - }, + remoteAddr: ipv4SubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: ipv4SubnetBcast, + expectedRemoteLinkAddress: header.EthernetBroadcastAddress, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut | stack.PacketLoop, }, // Broadcast to a locally attached /31 subnet does not populate the // broadcast MAC. @@ -3393,13 +3395,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet31Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix31.Address, - RemoteAddress: ipv4Subnet31Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet31Bcast, + expectedLocalAddress: ipv4AddrPrefix31.Address, + expectedRemoteAddress: ipv4Subnet31Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a locally attached /32 subnet does not populate the // broadcast MAC. @@ -3415,13 +3415,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet32Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix32.Address, - RemoteAddress: ipv4Subnet32Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet32Bcast, + expectedLocalAddress: ipv4AddrPrefix32.Address, + expectedRemoteAddress: ipv4Subnet32Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // IPv6 has no notion of a broadcast. { @@ -3436,13 +3434,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv6SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv6Addr.Address, - RemoteAddress: ipv6SubnetBcast, - NetProto: header.IPv6ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv6SubnetBcast, + expectedLocalAddress: ipv6Addr.Address, + expectedRemoteAddress: ipv6SubnetBcast, + expectedNetProto: header.IPv6ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a remote subnet in the route table is send to the next-hop // gateway. @@ -3459,14 +3455,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to an unknown subnet follows the default route. Note that this // is essentially just routing an unknown destination IP, because w/o any @@ -3484,14 +3478,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, } @@ -3520,10 +3512,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) + if err != nil { t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { - t.Errorf("route mismatch (-want +got):\n%s", diff) + } + if r.LocalAddress != test.expectedLocalAddress { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress) + } + if r.RemoteAddress != test.expectedRemoteAddress { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress) + } + if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { + t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) + } + if r.NextHop != test.expectedNextHop { + t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop) + } + if r.NetProto != test.expectedNetProto { + t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto) + } + if r.Loop != test.expectedLoop { + t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop) } }) } @@ -4091,10 +4100,12 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if r != nil { + defer r.Release() + } if err != test.findRouteErr { t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) } - defer r.Release() if test.findRouteErr != nil { return @@ -4152,3 +4163,63 @@ func TestFindRouteWithForwarding(t *testing.T) { }) } } + +func TestWritePacketToRemote(t *testing.T) { + const nicID = 1 + const MTU = 1280 + e := channel.New(1, MTU, linkAddr1) + s := stack.New(stack.Options{}) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("CreateNIC(%d) = %s", nicID, err) + } + tests := []struct { + name string + protocol tcpip.NetworkProtocolNumber + payload []byte + }{ + { + name: "SuccessIPv4", + protocol: header.IPv4ProtocolNumber, + payload: []byte{1, 2, 3, 4}, + }, + { + name: "SuccessIPv6", + protocol: header.IPv6ProtocolNumber, + payload: []byte{5, 6, 7, 8}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) + } + + pkt, ok := e.Read() + if got, want := ok, true; got != want { + t.Fatalf("e.Read() = %t, want %t", got, want) + } + if got, want := pkt.Proto, test.protocol; got != want { + t.Fatalf("pkt.Proto = %d, want %d", got, want) + } + if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) + } + if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) + } + }) + } + + t.Run("InvalidNICID", func(t *testing.T) { + if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + } + pkt, ok := e.Read() + if got, want := ok, false; got != want { + t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) + } + }) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 41a8e5ad0..a692af20b 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -141,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -307,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { - t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) - } + ep.SocketOptions().SetReusePort(endpoint.reuse) bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5b9043d85..66eb562ba 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -38,14 +38,15 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler proto *fakeTransportProtocol peerAddr tcpip.Address - route stack.Route + route *stack.Route uniqueID uint64 // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint + acceptQueue []*fakeTransportEndpoint // ops is used to set and get socket options. ops tcpip.SocketOptions @@ -64,8 +65,11 @@ func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } + func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep.ops.InitHandler(ep) + return ep } func (f *fakeTransportEndpoint) Abort() { @@ -105,8 +109,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. Currently not supported. @@ -114,21 +118,11 @@ func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Erro return tcpip.ErrInvalidEndpointState } -// SetSockOptBool sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - // SetSockOptInt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -189,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai if len(f.acceptQueue) == 0 { return nil, nil, nil } - a := &f.acceptQueue[0] + a := f.acceptQueue[0] f.acceptQueue = f.acceptQueue[1:] return a, nil, nil } @@ -206,7 +200,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { ); err != nil { return err } - f.acceptQueue = []fakeTransportEndpoint{} + f.acceptQueue = []*fakeTransportEndpoint{} return nil } @@ -232,7 +226,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * } route.ResolveWith(pkt.SourceLinkAddress()) - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + ep := &fakeTransportEndpoint{ TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -240,7 +234,9 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * proto: f.proto, peerAddr: route.RemoteAddress, route: route, - }) + } + ep.ops.InitHandler(ep) + f.acceptQueue = append(f.acceptQueue, ep) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f9e83dd1c..45fa62720 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -49,8 +49,9 @@ const ipv4AddressSize = 4 // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // -// Note: to support save / restore, it is important that all tcpip errors have -// distinct error messages. +// All errors must have unique msg strings. +// +// +stateify savable type Error struct { msg string @@ -247,6 +248,16 @@ func (a Address) WithPrefix() AddressWithPrefix { } } +// Unspecified returns true if the address is unspecified. +func (a Address) Unspecified() bool { + for _, b := range a { + if b != 0 { + return false + } + } + return true +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -481,6 +492,14 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is + // set. + HasOriginalDstAddress bool + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress FullAddress } // PacketOwner is used to get UID and GID of the packet. @@ -535,7 +554,7 @@ type Endpoint interface { // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. - Peek([][]byte) (int64, ControlMessages, *Error) + Peek([][]byte) (int64, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -593,10 +612,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt SettableSocketOption) *Error - // SetSockOptBool sets a socket option, for simple cases where a value - // has the bool type. - SetSockOptBool(opt SockOptBool, v bool) *Error - // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. SetSockOptInt(opt SockOptInt, v int) *Error @@ -604,10 +619,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt GettableSocketOption) *Error - // GetSockOptBool gets a socket option for simple cases where a return - // value has the bool type. - GetSockOptBool(SockOptBool) (bool, *Error) - // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. GetSockOptInt(SockOptInt) (int, *Error) @@ -694,79 +705,6 @@ type WriteOptions struct { Atomic bool } -// SockOptBool represents socket options which values have the bool type. -type SockOptBool int - -const ( - // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if - // data should be held until segments are full by the TCP transport - // protocol. - CorkOption SockOptBool = iota - - // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if - // data should be sent out immediately by the transport protocol. For - // TCP, it determines if the Nagle algorithm is on or off. - DelayOption - - // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to - // specify whether TCP keepalive is enabled for this socket. - KeepaliveEnabledOption - - // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to - // specify whether multicast packets sent over a non-loopback interface - // will be looped back. - MulticastLoopOption - - // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify - // whether UDP checksum is disabled for this socket. - NoChecksumOption - - // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify - // whether SCM_CREDENTIALS socket control messages are enabled. - // - // Only supported on Unix sockets. - PasscredOption - - // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. - QuickAckOption - - // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to - // specify if the IPV6_TCLASS ancillary message is passed with incoming - // packets. - ReceiveTClassOption - - // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify - // if the TOS ancillary message is passed with incoming packets. - ReceiveTOSOption - - // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to - // specify if more inforamtion is provided with incoming packets such as - // interface index and address. - ReceiveIPPacketInfoOption - - // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to - // specify whether Bind() should allow reuse of local address. - ReuseAddressOption - - // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit - // multiple sockets to be bound to an identical socket address. - ReusePortOption - - // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify - // whether an IPv6 socket is to be restricted to sending and receiving - // IPv6 packets only. - V6OnlyOption - - // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw - // endpoint that all packets being written have an IP header and the - // endpoint should not attach an IP header. - IPHdrIncludedOption - - // AcceptConnOption is used by GetSockOptBool to indicate if the - // socket is a listening socket. - AcceptConnOption -) - // SockOptInt represents socket options which values have the int type. type SockOptInt int @@ -1158,14 +1096,6 @@ type RemoveMembershipOption MembershipOption func (*RemoveMembershipOption) isSettableSocketOption() {} -// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP out-of-band data is delivered along with the normal in-band data. -type OutOfBandInlineOption int - -func (*OutOfBandInlineOption) isGettableSocketOption() {} - -func (*OutOfBandInlineOption) isSettableSocketOption() {} - // SocketDetachFilterOption is used by SetSockOpt to detach a previously attached // classic BPF filter on a given endpoint. type SocketDetachFilterOption int @@ -1215,10 +1145,6 @@ type LingerOption struct { Timeout time.Duration } -func (*LingerOption) isGettableSocketOption() {} - -func (*LingerOption) isSettableSocketOption() {} - // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable @@ -1389,6 +1315,18 @@ type ICMPv6PacketStats struct { // RedirectMsg is the total number of ICMPv6 redirect message packets // counted. RedirectMsg *StatCounter + + // MulticastListenerQuery is the total number of Multicast Listener Query + // messages counted. + MulticastListenerQuery *StatCounter + + // MulticastListenerReport is the total number of Multicast Listener Report + // messages counted. + MulticastListenerReport *StatCounter + + // MulticastListenerDone is the total number of Multicast Listener Done + // messages counted. + MulticastListenerDone *StatCounter } // ICMPv4SentPacketStats collects outbound ICMPv4-specific stats. @@ -1430,6 +1368,10 @@ type ICMPv6SentPacketStats struct { type ICMPv6ReceivedPacketStats struct { ICMPv6PacketStats + // Unrecognized is the total number of ICMPv6 packets received that the + // transport layer does not know how to parse. + Unrecognized *StatCounter + // Invalid is the total number of ICMPv6 packets received that the // transport layer could not parse. Invalid *StatCounter @@ -1439,25 +1381,90 @@ type ICMPv6ReceivedPacketStats struct { RouterOnlyPacketsDroppedByHost *StatCounter } -// ICMPStats collects ICMP-specific stats (both v4 and v6). -type ICMPStats struct { +// ICMPv4Stats collects ICMPv4-specific stats. +type ICMPv4Stats struct { // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type // and a single count of packets which failed to write to the link // layer. - V4PacketsSent ICMPv4SentPacketStats + PacketsSent ICMPv4SentPacketStats // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4 // packet type and a single count of invalid packets received. - V4PacketsReceived ICMPv4ReceivedPacketStats + PacketsReceived ICMPv4ReceivedPacketStats +} +// ICMPv6Stats collects ICMPv6-specific stats. +type ICMPv6Stats struct { // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type // and a single count of packets which failed to write to the link // layer. - V6PacketsSent ICMPv6SentPacketStats + PacketsSent ICMPv6SentPacketStats // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6 // packet type and a single count of invalid packets received. - V6PacketsReceived ICMPv6ReceivedPacketStats + PacketsReceived ICMPv6ReceivedPacketStats +} + +// ICMPStats collects ICMP-specific stats (both v4 and v6). +type ICMPStats struct { + // V4 contains the ICMPv4-specifics stats. + V4 ICMPv4Stats + + // V6 contains the ICMPv4-specifics stats. + V6 ICMPv6Stats +} + +// IGMPPacketStats enumerates counts for all IGMP packet types. +type IGMPPacketStats struct { + // MembershipQuery is the total number of Membership Query messages counted. + MembershipQuery *StatCounter + + // V1MembershipReport is the total number of Version 1 Membership Report + // messages counted. + V1MembershipReport *StatCounter + + // V2MembershipReport is the total number of Version 2 Membership Report + // messages counted. + V2MembershipReport *StatCounter + + // LeaveGroup is the total number of Leave Group messages counted. + LeaveGroup *StatCounter +} + +// IGMPSentPacketStats collects outbound IGMP-specific stats. +type IGMPSentPacketStats struct { + IGMPPacketStats + + // Dropped is the total number of IGMP packets dropped. + Dropped *StatCounter +} + +// IGMPReceivedPacketStats collects inbound IGMP-specific stats. +type IGMPReceivedPacketStats struct { + IGMPPacketStats + + // Invalid is the total number of IGMP packets received that IGMP could not + // parse. + Invalid *StatCounter + + // ChecksumErrors is the total number of IGMP packets dropped due to bad + // checksums. + ChecksumErrors *StatCounter + + // Unrecognized is the total number of unrecognized messages counted, these + // are silently ignored for forward-compatibilty. + Unrecognized *StatCounter +} + +// IGMPStats colelcts IGMP-specific stats. +type IGMPStats struct { + // IGMPSentPacketStats contains counts of sent packets by IGMP packet type + // and a single count of invalid packets received. + PacketsSent IGMPSentPacketStats + + // IGMPReceivedPacketStats contains counts of received packets by IGMP packet + // type and a single count of invalid packets received. + PacketsReceived IGMPReceivedPacketStats } // IPStats collects IP-specific stats (both v4 and v6). @@ -1665,6 +1672,9 @@ type Stats struct { // ICMP breaks out ICMP-specific stats (both v4 and v6). ICMP ICMPStats + // IGMP breaks out IGMP-specific stats. + IGMP IGMPStats + // IP breaks out IP-specific stats (both v4 and v6). IP IPStats diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 1c8e2bc34..c461da137 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -226,3 +226,47 @@ func TestAddressWithPrefixSubnet(t *testing.T) { } } } + +func TestAddressUnspecified(t *testing.T) { + tests := []struct { + addr Address + unspecified bool + }{ + { + addr: "", + unspecified: true, + }, + { + addr: "\x00", + unspecified: true, + }, + { + addr: "\x01", + unspecified: false, + }, + { + addr: "\x00\x00", + unspecified: true, + }, + { + addr: "\x01\x00", + unspecified: false, + }, + { + addr: "\x00\x01", + unspecified: false, + }, + { + addr: "\x01\x01", + unspecified: false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) { + if got := test.addr.Unspecified(); got != test.unspecified { + t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 9b0f3b675..800025fb9 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -25,6 +25,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 421da1add..baaa741cd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -70,8 +71,8 @@ func TestInitialLoopbackAddresses(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDispatcher{}, - AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDispatcher{}, + AutoGenLinkLocal: true, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(nicID tcpip.NICID, nicName string) string { t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) @@ -93,9 +94,10 @@ func TestInitialLoopbackAddresses(t *testing.T) { } } -// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address. -func TestLoopbackAcceptAllInSubnet(t *testing.T) { +// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address and UDP +// traffic is sent/received correctly. +func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { const ( nicID = 1 localPort = 80 @@ -107,7 +109,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr, } - ipv4Bytes := []byte(ipv4Addr.Address) + ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) ipv4Bytes[len(ipv4Bytes)-1]++ otherIPv4Address := tcpip.Address(ipv4Bytes) @@ -129,7 +131,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { { name: "IPv4 bind to wildcard and send to assigned address", addAddress: ipv4ProtocolAddress, - dstAddr: ipv4Addr.Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, expectRx: true, }, { @@ -148,7 +150,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { name: "IPv4 bind to other subnet-local address and send to assigned address", addAddress: ipv4ProtocolAddress, bindAddr: otherIPv4Address, - dstAddr: ipv4Addr.Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, expectRx: false, }, { @@ -161,7 +163,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { { name: "IPv4 bind to assigned address and send to other subnet-local address", addAddress: ipv4ProtocolAddress, - bindAddr: ipv4Addr.Address, + bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, dstAddr: otherIPv4Address, expectRx: false, }, @@ -236,13 +238,17 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) } - if gotPayload, _, err := rep.Read(nil); test.expectRx { + var addr tcpip.FullAddress + if gotPayload, _, err := rep.Read(&addr); test.expectRx { if err != nil { - t.Fatalf("reep.Read(nil): %s", err) + t.Fatalf("reep.Read(_): %s", err) } if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + if addr.Addr != test.addAddress.AddressWithPrefix.Address { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) + } } else { if err != tcpip.ErrWouldBlock { t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) @@ -312,3 +318,168 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) } } + +// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address and TCP +// traffic is sent/received correctly. +func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { + const ( + nicID = 1 + localPort = 80 + ) + + ipv4ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8 + ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) + ipv4Bytes[len(ipv4Bytes)-1]++ + otherIPv4Address := tcpip.Address(ipv4Bytes) + + ipv6ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + } + ipv6Bytes := []byte(ipv6Addr.Address) + ipv6Bytes[len(ipv6Bytes)-1]++ + otherIPv6Address := tcpip.Address(ipv6Bytes) + + tests := []struct { + name string + addAddress tcpip.ProtocolAddress + bindAddr tcpip.Address + dstAddr tcpip.Address + expectAccept bool + }{ + { + name: "IPv4 bind to wildcard and send to assigned address", + addAddress: ipv4ProtocolAddress, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + expectAccept: true, + }, + { + name: "IPv4 bind to wildcard and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + dstAddr: otherIPv4Address, + expectAccept: true, + }, + { + name: "IPv4 bind to wildcard send to other address", + addAddress: ipv4ProtocolAddress, + dstAddr: remoteIPv4Addr, + expectAccept: false, + }, + { + name: "IPv4 bind to other subnet-local address and send to assigned address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + expectAccept: false, + }, + { + name: "IPv4 bind and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: otherIPv4Address, + expectAccept: true, + }, + { + name: "IPv4 bind to assigned address and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + dstAddr: otherIPv4Address, + expectAccept: false, + }, + + { + name: "IPv6 bind and send to assigned address", + addAddress: ipv6ProtocolAddress, + bindAddr: ipv6Addr.Address, + dstAddr: ipv6Addr.Address, + expectAccept: true, + }, + { + name: "IPv6 bind to wildcard and send to other subnet-local address", + addAddress: ipv6ProtocolAddress, + dstAddr: otherIPv6Address, + expectAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer listeningEndpoint.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := listeningEndpoint.Bind(bindAddr); err != nil { + t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err) + } + + if err := listeningEndpoint.Listen(1); err != nil { + t.Fatalf("listeningEndpoint.Listen(1): %s", err) + } + + connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer connectingEndpoint.Close() + + connectAddr := tcpip.FullAddress{ + Addr: test.dstAddr, + Port: localPort, + } + if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } + + if !test.expectAccept { + if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + return + } + + // Wait for the listening endpoint to be "readable". That is, wait for a + // new connection. + <-ch + var addr tcpip.FullAddress + if _, _, err := listeningEndpoint.Accept(&addr); err != nil { + t.Fatalf("listeningEndpoint.Accept(nil): %s", err) + } + if addr.Addr != test.addAddress.AddressWithPrefix.Address { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 9d30329f5..2e59f6a42 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -96,11 +96,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -272,11 +272,11 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: uint16(payloadLen), + TransportProtocol: udp.ProtocolNumber, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -510,10 +510,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) - } - + ep.SocketOptions().SetReuseAddress(true) ep.SocketOptions().SetBroadcast(true) bindAddr := tcpip.FullAddress{Port: localPort} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 440cb0352..74fe19e98 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -49,6 +49,7 @@ const ( // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. @@ -71,11 +72,9 @@ type endpoint struct { // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -85,7 +84,7 @@ type endpoint struct { } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return &endpoint{ + ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, @@ -96,7 +95,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), - }, nil + } + ep.ops.InitHandler(ep) + return ep, nil } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -129,7 +130,10 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. e.state = stateClosed @@ -142,6 +146,7 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } @@ -267,26 +272,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } } - var route *stack.Route - if to == nil { - route = &e.route - - if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, - // given that it may need to change in Route.Resolve() - // call below. - e.mu.RUnlock() - defer e.mu.RLock() - - e.mu.Lock() - defer e.mu.Unlock() - - // Recheck state after lock was re-acquired. - if e.state != stateConnected { - return 0, nil, tcpip.ErrInvalidEndpointState - } - } - } else { + route := e.route + if to != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC @@ -310,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r } if route.IsResolutionRequired() { @@ -343,26 +330,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.SocketDetachFilterOption: - return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - } - return nil -} - -// SetSockOptBool sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return nil } @@ -378,17 +351,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -426,16 +388,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -857,6 +810,7 @@ func (*endpoint) LastError() *tcpip.Error { return nil } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 3bff3755a..9faab4b9e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -60,6 +60,8 @@ type packet struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` @@ -83,8 +85,6 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -107,6 +107,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, } + ep.ops.InitHandler(ep) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -203,8 +204,8 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha } // Peek implements tcpip.Endpoint.Peek. -func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be @@ -303,26 +304,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - ep.mu.Lock() - ep.linger = *v - ep.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption -} - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -378,26 +368,7 @@ func (ep *endpoint) LastError() *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - ep.mu.Lock() - *o = ep.linger - ep.mu.Unlock() - return nil - - default: - return tcpip.ErrNotSupported - } -} - -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.AcceptConnOption: - return false, nil - default: - return false, tcpip.ErrNotSupported - } + return tcpip.ErrNotSupported } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -551,8 +522,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } +// SetOwner implements tcpip.Endpoint.SetOwner. func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 4ae1f92ab..87c60bdab 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -58,12 +58,13 @@ type rawPacket struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool - hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -82,10 +83,8 @@ type endpoint struct { bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -114,8 +113,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, - hdrIncluded: !associated, } + e.ops.InitHandler(e) + e.ops.SetHeaderIncluded(!associated) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -170,9 +170,11 @@ func (e *endpoint) Close() { e.rcvList.Remove(e.rcvList.Front()) } - if e.connected { + e.connected = false + + if e.route != nil { e.route.Release() - e.connected = false + e.route = nil } e.closed = true @@ -223,6 +225,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrInvalidOptionValue } + if opts.To != nil { + // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + return 0, nil, tcpip.ErrInvalidOptionValue + } + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -266,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -296,7 +305,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if e.route.IsResolutionRequired() { - savedRoute := &e.route + savedRoute := e.route // Promote lock to exclusive if using a shared route, // given that it may need to change in finishWrite. e.mu.RUnlock() @@ -304,7 +313,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Make sure that the route didn't change during the // time we didn't hold the lock. - if !e.connected || savedRoute != &e.route { + if !e.connected || savedRoute != e.route { e.mu.Unlock() return 0, nil, tcpip.ErrInvalidEndpointState } @@ -314,7 +323,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return n, ch, err } - n, ch, err := e.finishWrite(payloadBytes, &e.route) + n, ch, err := e.finishWrite(payloadBytes, e.route) e.mu.RUnlock() return n, ch, err } @@ -335,7 +344,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - n, ch, err := e.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, route) route.Release() e.mu.RUnlock() return n, ch, err @@ -356,7 +365,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), }) @@ -382,8 +391,8 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -393,6 +402,11 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + return tcpip.ErrAddressFamilyNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -516,33 +530,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - e.hdrIncluded = v - e.mu.Unlock() - return nil - } - return tcpip.ErrUnknownProtocolOption -} - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -589,33 +585,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - v := e.hdrIncluded - e.mu.Unlock() - return v, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -756,10 +726,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} +// LastError implements tcpip.Endpoint.LastError. func (*endpoint) LastError() *tcpip.Error { return nil } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 7d97cbdc7..4a7e1c039 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if e.connected { var err *tcpip.Error - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false) + // TODO(gvisor.dev/issue/4906): Properly restore the route with the right + // remote address. We used to pass e.remote.RemoteAddress which was + // effectively the empty address but since moving e.route to hold a pointer + // to a route instead of the route by value, we pass the empty address + // directly. Obviously this was always wrong since we should provide the + // remote address we were connected to, to properly restore the route. + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 518449602..cf232b508 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "more_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -45,7 +45,9 @@ go_library( "rcv.go", "rcv_state.go", "reno.go", + "reno_recovery.go", "sack.go", + "sack_recovery.go", "sack_scoreboard.go", "segment.go", "segment_heap.go", @@ -91,7 +93,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], - shard_count = 10, + shard_count = more_shards, deps = [ ":tcp", "//pkg/rand", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6e5adc383..3e1041cbe 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i route.ResolveWith(s.remoteLinkAddr) n := newEndpoint(l.stack, netProto, queue) - n.v6only = l.v6Only + n.ops.SetV6Only(l.v6Only) n.ID = s.id n.boundNICID = s.nicID n.route = route @@ -599,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, } - if err := e.sendSynTCP(&route, fields, synOpts); err != nil { + if err := e.sendSynTCP(route, fields, synOpts); err != nil { return err } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() @@ -752,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() - v6Only := e.v6only + v6Only := e.ops.GetV6Only() ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) defer func() { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index ac6d879a7..c944dccc0 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,6 +16,7 @@ package tcp import ( "encoding/binary" + "math" "time" "gvisor.dev/gvisor/pkg/rand" @@ -133,7 +134,7 @@ func FindWndScale(wnd seqnum.Size) int { return 0 } - max := seqnum.Size(0xffff) + max := seqnum.Size(math.MaxUint16) s := 0 for wnd > max && s < header.MaxWndScale { s++ @@ -300,7 +301,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if ttl == 0 { ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -361,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -496,7 +497,7 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } } @@ -547,7 +548,7 @@ func (h *handshake) start() *tcpip.Error { } h.sendSYNOpts = synOpts - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -575,7 +576,6 @@ func (h *handshake) complete() *tcpip.Error { return err } defer timer.stop() - for h.state != handshakeCompleted { // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held // throughout handshake processing). @@ -597,7 +597,7 @@ func (h *handshake) complete() *tcpip.Error { // the connection with another ACK or data (as ACKs are never // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -631,9 +631,8 @@ func (h *handshake) complete() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } - case wakerForNewSegment: if err := h.processSegments(); err != nil { return err @@ -820,8 +819,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso data = data.Clone(nil) optLen := len(tf.opts) - if tf.rcvWnd > 0xffff { - tf.rcvWnd = 0xffff + if tf.rcvWnd > math.MaxUint16 { + tf.rcvWnd = math.MaxUint16 } mss := int(gso.MSS) @@ -865,8 +864,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // network endpoint and under the provided identity. func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) - if tf.rcvWnd > 0xffff { - tf.rcvWnd = 0xffff + if tf.rcvWnd > math.MaxUint16 { + tf.rcvWnd = math.MaxUint16 } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { @@ -941,7 +940,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, tcpFields{ + err := e.sendTCP(e.route, tcpFields{ id: e.ID, ttl: e.ttl, tos: e.sendTOS, @@ -1002,7 +1001,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) - e.HardError = err + e.hardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk @@ -1080,7 +1079,7 @@ func (e *endpoint) transitionToStateCloseLocked() { // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) - if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { + if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } @@ -1141,7 +1140,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. case StateCloseWait: e.transitionToStateCloseLocked() - e.HardError = tcpip.ErrAborted + e.hardError = tcpip.ErrAborted e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1286,7 +1285,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() - if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + if !e.SocketOptions().GetKeepAlive() || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } @@ -1323,7 +1322,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. - if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return @@ -1353,7 +1352,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ epilogue := func() { // e.mu is expected to be hold upon entering this section. - if e.snd != nil { e.snd.resendTimer.cleanup() } @@ -1383,7 +1381,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.lastErrorMu.Unlock() e.setEndpointState(StateError) - e.HardError = err + e.hardError = err e.workerCleanup = true // Lock released below. @@ -1638,7 +1636,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() } extTW, newSyn := e.rcv.handleTimeWaitSegment(s) if newSyn { - info := e.EndpointInfo.TransportEndpointInfo + info := e.TransportEndpointInfo newID := info.ID newID.RemoteAddress = "" newID.RemotePort = 0 diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index a6f25896b..1d1b01a6c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -405,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) { } } - // Make sure we get the same error when calling the original ep and the - // new one. This validates that v4-mapped endpoints are still able to - // query the V6Only flag, whereas pure v4 endpoints are not. - _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { - t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) - } - // Check the peer address. addr, err := nep.GetRemoteAddress() if err != nil { @@ -530,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress - nep, _, err := c.EP.Accept(&addr) + _, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept(&addr) + _, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -548,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) { if addr.Addr != context.TestV6Addr { t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) } - - // Make sure we can still query the v6 only status of the new endpoint, - // that is, that it is in fact a v6 socket. - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) - } } func TestV4AcceptOnV4(t *testing.T) { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4f4f4c65e..bb0795f78 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -310,16 +310,12 @@ type Stats struct { func (*Stats) IsEndpointStats() {} // EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. +// can be queried by monitoring tools. This exists to allow tcp-only state to +// be exposed. // // +stateify savable type EndpointInfo struct { stack.TransportEndpointInfo - - // HardError is meaningful only when state is stateError. It stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. HardError is protected by endpoint mu. - HardError *tcpip.Error `state:".(string)"` } // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -367,6 +363,7 @@ func (*EndpointInfo) IsEndpointInfo() {} // +stateify savable type endpoint struct { EndpointInfo + tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the // a given tcp processor goroutine. @@ -386,6 +383,11 @@ type endpoint struct { waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 + // hardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. hardError is protected by endpoint mu. + hardError *tcpip.Error `state:".(string)"` + // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. lastErrorMu sync.Mutex `state:"nosave"` @@ -421,7 +423,10 @@ type endpoint struct { // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. - mu sync.Mutex `state:"nosave"` + // + // During handshake, mu is locked by the protocol listen goroutine and + // released by the handshake completion goroutine. + mu sync.CrossGoroutineMutex `state:"nosave"` ownedByUser uint32 // state must be read/set using the EndpointState()/setEndpointState() @@ -436,9 +441,8 @@ type endpoint struct { isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` boundNICID tcpip.NICID - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 - v6only bool isConnectNotified bool // h stores a reference to the current handshake state if the endpoint is in @@ -506,24 +510,9 @@ type endpoint struct { // delay is a boolean (0 is false) and must be accessed atomically. delay uint32 - // cork holds back segments until full. - // - // cork is a boolean (0 is false) and must be accessed atomically. - cork uint32 - // scoreboard holds TCP SACK Scoreboard information for this endpoint. scoreboard *SACKScoreboard - // The options below aren't implemented, but we remember the user - // settings because applications expect to be able to set/query these - // options. - - // slowAck holds the negated state of quick ack. It is stubbed out and - // does nothing. - // - // slowAck is a boolean (0 is false) and must be accessed atomically. - slowAck uint32 - // segmentQueue is used to hand received segments to the protocol // goroutine. Segments are queued as long as the queue is not full, // and dropped when it is. @@ -685,9 +674,6 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -701,7 +687,7 @@ func (e *endpoint) UniqueID() uint64 { // // If userMSS is non-zero and is not greater than the maximum possible MSS for // r, it will be used; otherwise, the maximum possible MSS will be used. -func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { +func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 { // The maximum possible MSS is dependent on the route. // TODO(b/143359391): Respect TCP Min and Max size. maxMSS := uint16(r.MTU() - header.TCPMinimumSize) @@ -850,7 +836,6 @@ func (e *endpoint) recentTimestamp() uint32 { // +stateify savable type keepalive struct { sync.Mutex `state:"nosave"` - enabled bool idle time.Duration interval time.Duration count int @@ -884,6 +869,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } + e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) + e.ops.SetQuickAck(true) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -907,7 +895,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { - e.SetSockOptBool(tcpip.DelayOption, true) + e.ops.SetDelayOption(true) } var tcpLT tcpip.TCPLingerTimeoutOption @@ -1049,7 +1037,8 @@ func (e *endpoint) Close() { return } - if e.linger.Enabled && e.linger.Timeout == 0 { + linger := e.SocketOptions().GetLinger() + if linger.Enabled && linger.Timeout == 0 { s := e.EndpointState() isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv if isResetState { @@ -1169,7 +1158,11 @@ func (e *endpoint) cleanupLocked() { e.boundPortFlags = ports.Flags{} e.boundDest = tcpip.FullAddress{} - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -1279,11 +1272,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -func (e *endpoint) LastError() *tcpip.Error { +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) hardErrorLocked() *tcpip.Error { + err := e.hardError + e.hardError = nil + return err +} + +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) lastErrorLocked() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1291,6 +1293,16 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// LastError implements tcpip.Endpoint.LastError. +func (e *endpoint) LastError() *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + if err := e.hardErrorLocked(); err != nil { + return err + } + return e.lastErrorLocked() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1312,9 +1324,11 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.HardError if s == StateError { - return buffer.View{}, tcpip.ControlMessages{}, he + if err := e.hardErrorLocked(); err != nil { + return buffer.View{}, tcpip.ControlMessages{}, err + } + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } e.stats.ReadErrors.NotConnected.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected @@ -1370,9 +1384,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { + // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: - return 0, e.HardError + if err := e.hardErrorLocked(); err != nil { + return 0, err + } + return 0, tcpip.ErrClosedForSend case !s.connecting() && !s.connected(): return 0, tcpip.ErrClosedForSend case s.connecting(): @@ -1478,7 +1496,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1486,10 +1504,10 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.HardError + return 0, e.hardErrorLocked() } e.stats.ReadErrors.InvalidEndpointState.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } e.rcvListMu.Lock() @@ -1498,9 +1516,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + return 0, tcpip.ErrClosedForReceive } - return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + return 0, tcpip.ErrWouldBlock } // Make a copy of vec so we can modify the slide headers. @@ -1515,7 +1533,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro for len(v) > 0 { if len(vec) == 0 { - return num, tcpip.ControlMessages{}, nil + return num, nil } if len(vec[0]) == 0 { vec = vec[1:] @@ -1530,7 +1548,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } } - return num, tcpip.ControlMessages{}, nil + return num, nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1602,72 +1620,39 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo return false, false } -// SetSockOptBool sets a socket option. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - - case tcpip.CorkOption: - e.LockUser() - if !v { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - e.UnlockUser() - - case tcpip.DelayOption: - if v { - atomic.StoreUint32(&e.delay, 1) - } else { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - - case tcpip.QuickAckOption: - o := uint32(1) - if v { - o = 0 - } - atomic.StoreUint32(&e.slowAck, o) - - case tcpip.ReuseAddressOption: - e.LockUser() - e.portFlags.TupleOnly = v - e.UnlockUser() - - case tcpip.ReusePortOption: - e.LockUser() - e.portFlags.LoadBalanced = v - e.UnlockUser() +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.LockUser() + e.portFlags.TupleOnly = v + e.UnlockUser() +} - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.LockUser() + e.portFlags.LoadBalanced = v + e.UnlockUser() +} - // We only allow this to be set when we're in the initial state. - if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState - } +// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. +func (e *endpoint) OnKeepAliveSet(v bool) { + e.notifyProtocolGoroutine(notifyKeepaliveChanged) +} - e.LockUser() - e.v6only = v - e.UnlockUser() +// OnDelayOptionSet implements tcpip.SocketOptionsHandler.OnDelayOptionSet. +func (e *endpoint) OnDelayOptionSet(v bool) { + if !v { + // Handle delayed data. + e.sndWaker.Assert() } +} - return nil +// OnCorkOptionSet implements tcpip.SocketOptionsHandler.OnCorkOptionSet. +func (e *endpoint) OnCorkOptionSet(v bool) { + if !v { + // Handle the corked data. + e.sndWaker.Assert() + } } // SetSockOptInt sets a socket option. @@ -1851,9 +1836,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - case *tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(*v) @@ -1922,11 +1904,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.LockUser() - e.linger = *v - e.UnlockUser() - default: return nil } @@ -1949,67 +1926,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - - case tcpip.CorkOption: - return atomic.LoadUint32(&e.cork) != 0, nil - - case tcpip.DelayOption: - return atomic.LoadUint32(&e.delay) != 0, nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - return v, nil - - case tcpip.QuickAckOption: - v := atomic.LoadUint32(&e.slowAck) == 0 - return v, nil - - case tcpip.ReuseAddressOption: - e.LockUser() - v := e.portFlags.TupleOnly - e.UnlockUser() - - return v, nil - - case tcpip.ReusePortOption: - e.LockUser() - v := e.portFlags.LoadBalanced - e.UnlockUser() - - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.LockUser() - v := e.v6only - e.UnlockUser() - - return v, nil - - case tcpip.MulticastLoopOption: - return true, nil - - case tcpip.AcceptConnOption: - e.LockUser() - defer e.UnlockUser() - - return e.EndpointState() == StateListen, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -2120,10 +2036,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - *o = 1 - case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc @@ -2152,11 +2064,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { Port: port, } - case *tcpip.LingerOption: - e.LockUser() - *o = e.linger - e.UnlockUser() - default: return tcpip.ErrUnknownProtocolOption } @@ -2166,7 +2073,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -2243,7 +2150,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnecting case StateError: - return e.HardError + if err := e.hardErrorLocked(); err != nil { + return err + } + return tcpip.ErrConnectionAborted default: return tcpip.ErrInvalidEndpointState @@ -2417,7 +2327,7 @@ func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.setEndpointState(StateError) - e.HardError = err + e.hardError = err // Call cleanupLocked to free up any reservations. e.cleanupLocked() @@ -2697,7 +2607,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // v6only set to false. if netProto == header.IPv6ProtocolNumber { stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) - alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4 + alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4 if alsoBindToV4 { netProtos = append(netProtos, header.IPv4ProtocolNumber) } @@ -2782,7 +2692,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first - // land at the Dispatcher which then can either delivery using the + // land at the Dispatcher which then can either deliver using the // worker go routine or directly do the invoke the tcp processing inline // based on the state of the endpoint. } @@ -3079,6 +2989,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { Ssthresh: e.snd.sndSsthresh, SndCAAckCount: e.snd.sndCAAckCount, Outstanding: e.snd.outstanding, + SackedOut: e.snd.sackedOut, SndWnd: e.snd.sndWnd, SndUna: e.snd.sndUna, SndNxt: e.snd.sndNxt, @@ -3161,7 +3072,7 @@ func (e *endpoint) State() uint32 { func (e *endpoint) Info() tcpip.EndpointInfo { e.LockUser() // Make a copy of the endpoint info. - ret := e.EndpointInfo + ret := e.TransportEndpointInfo e.UnlockUser() return &ret } @@ -3187,6 +3098,7 @@ func (e *endpoint) Wait() { } } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index bb901c0f8..ba67176b5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -321,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { } // saveHardError is invoked by stateify. -func (e *EndpointInfo) saveHardError() string { - if e.HardError == nil { +func (e *endpoint) saveHardError() string { + if e.hardError == nil { return "" } - return e.HardError.String() + return e.hardError.String() } // loadHardError is invoked by stateify. -func (e *EndpointInfo) loadHardError(s string) { +func (e *endpoint) loadHardError(s string) { if s == "" { return } - e.HardError = tcpip.StringToError(s) + e.hardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2329aca4b..672159eed 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error ttl = route.DefaultTTL() } - return sendTCP(&route, tcpFields{ + return sendTCP(route, tcpFields{ id: s.id, ttl: ttl, tos: tos, diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 8e0b7c843..405a6dce7 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -16,6 +16,7 @@ package tcp import ( "container/heap" + "math" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -48,6 +49,10 @@ type receiver struct { rcvWndScale uint8 + // prevBufused is the snapshot of endpoint rcvBufUsed taken when we + // advertise a receive window. + prevBufUsed int + closed bool // pendingRcvdSegments is bounded by the receive buffer size of the @@ -80,9 +85,9 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // outgoing packets, we should use what we have advertised for acceptability // test. scaledWindowSize := r.rcvWnd >> r.rcvWndScale - if scaledWindowSize > 0xffff { + if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. - scaledWindowSize = 0xffff + scaledWindowSize = math.MaxUint16 } advertisedWindowSize := scaledWindowSize << r.rcvWndScale return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) @@ -106,6 +111,34 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) { func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() + unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + bufUsed := r.ep.receiveBufferUsed() + + // Grow the right edge of the window only for payloads larger than the + // the segment overhead OR if the application is actively consuming data. + // + // Avoiding growing the right edge otherwise, addresses a situation below: + // An application has been slow in reading data and we have burst of + // incoming segments lengths < segment overhead. Here, our available free + // memory would reduce drastically when compared to the advertised receive + // window. + // + // For example: With incoming 512 bytes segments, segment overhead of + // 552 bytes (at the time of writing this comment), with receive window + // starting from 1MB and with rcvAdvWndScale being 1, buffer would reach 0 + // when the curWnd is still 19436 bytes, because for every incoming segment + // newWnd would reduce by (552+512) >> rcvAdvWndScale (current value 1), + // while curWnd would reduce by 512 bytes. + // Such a situation causes us to keep tail dropping the incoming segments + // and never advertise zero receive window to the peer. + // + // Linux does a similar check for minimal sk_buff size (128): + // https://github.com/torvalds/linux/blob/d5beb3140f91b1c8a3d41b14d729aefa4dcc58bc/net/ipv4/tcp_input.c#L783 + // + // Also, if the application is reading the data, we keep growing the right + // edge, as we are still advertising a window that we think can be serviced. + toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed + // Update rcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we @@ -115,7 +148,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP rcvNxt rcvAcc new rcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) { + if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { // If the new window moves the right edge, then update rcvAcc. r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) } else { @@ -130,11 +163,22 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // receiver's estimated RTT. r.rcvWnd = newWnd r.rcvWUP = r.rcvNxt + r.prevBufUsed = bufUsed scaledWnd := r.rcvWnd >> r.rcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } + + // If we started off with a window larger than what can he held in + // the 16bit window field, we ceil the value to the max value. + if scaledWnd > math.MaxUint16 { + scaledWnd = seqnum.Size(math.MaxUint16) + + // Ensure that the stashed receive window always reflects what + // is being advertised. + r.rcvWnd = scaledWnd << r.rcvWndScale + } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go new file mode 100644 index 000000000..2aa708e97 --- /dev/null +++ b/pkg/tcpip/transport/tcp/reno_recovery.go @@ -0,0 +1,67 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +// renoRecovery stores the variables related to TCP Reno loss recovery +// algorithm. +// +// +stateify savable +type renoRecovery struct { + s *sender +} + +func newRenoRecovery(s *sender) *renoRecovery { + return &renoRecovery{s: s} +} + +func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { + ack := rcvdSeg.ackNumber + snd := rr.s + + // We are in fast recovery mode. Ignore the ack if it's out of range. + if !ack.InRange(snd.sndUna, snd.sndNxt+1) { + return + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. + if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window { + return + } + + // Inflate the congestion window if we're getting duplicate acks + // for the packet we retransmitted. + if !fastRetransmit && ack == snd.fr.first { + // We received a dup, inflate the congestion window by 1 packet + // if we're not at the max yet. Only inflate the window if + // regular FastRecovery is in use, RFC6675 does not require + // inflating cwnd on duplicate ACKs. + if snd.sndCwnd < snd.fr.maxCwnd { + snd.sndCwnd++ + } + return + } + + // A partial ack was received. Retransmit this packet and remember it + // so that we don't retransmit it again. + // + // We don't inflate the window because we're putting the same packet + // back onto the wire. + // + // N.B. The retransmit timer will be reset by the caller. + snd.fr.first = ack + snd.dupAckCount = 0 + snd.resendSegment() +} diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go new file mode 100644 index 000000000..7e813fa96 --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack_recovery.go @@ -0,0 +1,120 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import "gvisor.dev/gvisor/pkg/tcpip/seqnum" + +// sackRecovery stores the variables related to TCP SACK loss recovery +// algorithm. +// +// +stateify savable +type sackRecovery struct { + s *sender +} + +func newSACKRecovery(s *sender) *sackRecovery { + return &sackRecovery{s: s} +} + +// handleSACKRecovery implements the loss recovery phase as described in RFC6675 +// section 5, step C. +func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { + snd := sr.s + snd.SetPipe() + + if smss := int(snd.ep.scoreboard.SMSS()); limit > smss { + // Cap segment size limit to s.smss as SACK recovery requires + // that all retransmissions or new segments send during recovery + // be of <= SMSS. + limit = smss + } + + nextSegHint := snd.writeList.Front() + for snd.outstanding < snd.sndCwnd { + var nextSeg *segment + var rescueRtx bool + nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint) + if nextSeg == nil { + return dataSent + } + if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + // New data being sent. + + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. + // + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + // + // We pass s.smss as the limit as the Step 2) requires that + // new data sent should be of size s.smss or less. + if sent := snd.maybeSendSegment(nextSeg, limit, end); !sent { + return dataSent + } + dataSent = true + snd.outstanding++ + snd.writeNext = nextSeg.Next() + continue + } + + // Now handle the retransmission case where we matched either step 1,3 or 4 + // of the NextSeg algorithm. + // RFC 6675, Step C.4. + // + // "The estimate of the amount of data outstanding in the network + // must be updated by incrementing pipe by the number of octets + // transmitted in (C.1)." + snd.outstanding++ + dataSent = true + snd.sendSegment(nextSeg) + + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if rescueRtx { + // We do the last part of rule (4) of NextSeg here to update + // RescueRxt as until this point we don't know if we are going + // to use the rescue transmission. + snd.fr.rescueRxt = snd.fr.last + } else { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + snd.fr.highRxt = segEnd - 1 + } + } + return dataSent +} + +func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { + snd := sr.s + if fastRetransmit { + snd.resendSegment() + } + + // We are in fast recovery mode. Ignore the ack if it's out of range. + if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) { + return + } + + // RFC 6675 recovery algorithm step C 1-5. + end := snd.sndUna.Add(snd.sndWnd) + dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end) + snd.postXmit(dataSent) +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 2091989cc..5ef73ec74 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -204,7 +204,7 @@ func (s *segment) payloadSize() int { // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { - return segSize + s.data.Size() + return SegSize + s.data.Size() } // parse populates the sequence & ack numbers, flags, and window fields of the diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go index 0ab7b8f56..392ff0859 100644 --- a/pkg/tcpip/transport/tcp/segment_unsafe.go +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -19,5 +19,6 @@ import ( ) const ( - segSize = int(unsafe.Sizeof(segment{})) + // SegSize is the minimal size of the segment overhead. + SegSize = int(unsafe.Sizeof(segment{})) ) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 0e0fdf14c..cc991aba6 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -18,7 +18,6 @@ import ( "fmt" "math" "sort" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -92,6 +91,17 @@ type congestionControl interface { PostRecovery() } +// lossRecovery is an interface that must be implemented by any supported +// loss recovery algorithm. +type lossRecovery interface { + // DoRecovery is invoked when loss is detected and segments need + // to be retransmitted. The cumulative or selective ACK is passed along + // with the flag which identifies whether the connection entered fast + // retransmit with this ACK and to retransmit the first unacknowledged + // segment. + DoRecovery(rcvdSeg *segment, fastRetransmit bool) +} + // sender holds the state necessary to send TCP segments. // // +stateify savable @@ -108,6 +118,9 @@ type sender struct { // fr holds state related to fast recovery. fr fastRecovery + // lr is the loss recovery algorithm used by the sender. + lr lossRecovery + // sndCwnd is the congestion window, in packets. sndCwnd int @@ -124,6 +137,9 @@ type sender struct { // that have been sent but not yet acknowledged. outstanding int + // sackedOut is the number of packets which are selectively acked. + sackedOut int + // sndWnd is the send window size. sndWnd seqnum.Size @@ -276,6 +292,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.cc = s.initCongestionControl(ep.cc) + s.lr = s.initLossRecovery() + // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { @@ -330,6 +348,14 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon } } +// initLossRecovery initiates the loss recovery algorithm for the sender. +func (s *sender) initLossRecovery() lossRecovery { + if s.ep.sackPermitted { + return newSACKRecovery(s) + } + return newRenoRecovery(s) +} + // updateMaxPayloadSize updates the maximum payload size based on the given // MTU. If this is in response to "packet too big" control packets (indicated // by the count argument), it also reduces the number of outstanding packets and @@ -349,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } + oldMSS := s.maxPayloadSize s.maxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) @@ -371,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // Rewind writeNext to the first segment exceeding the MTU. Do nothing // if it is already before such a packet. + nextSeg := s.writeNext for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { if seg == s.writeNext { // We got to writeNext before we could find a segment @@ -378,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { break } - if seg.data.Size() > m { + if nextSeg == s.writeNext && seg.data.Size() > m { // We found a segment exceeding the MTU. Rewind // writeNext and try to retransmit it. - s.writeNext = seg - break + nextSeg = seg + } + + if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Update sackedOut for new maximum payload size. + s.sackedOut -= s.pCount(seg, oldMSS) + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } } // Since we likely reduced the number of outstanding packets, we may be // ready to send some more. + s.writeNext = nextSeg s.sendData() } @@ -550,7 +584,7 @@ func (s *sender) retransmitTimerExpired() bool { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it // has already been updated when entered fast-recovery. - s.leaveFastRecovery() + s.leaveRecovery() } s.state = RTORecovery @@ -606,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool { // pCount returns the number of packets in the segment. Due to GSO, a segment // can be composed of multiple packets. -func (s *sender) pCount(seg *segment) int { +func (s *sender) pCount(seg *segment, maxPayloadSize int) int { size := seg.data.Size() if size == 0 { return 1 } - return (size-1)/s.maxPayloadSize + 1 + return (size-1)/maxPayloadSize + 1 } // splitSeg splits a given segment at the size specified and inserts the @@ -789,7 +823,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } if !nextTooBig && seg.data.Size() < available { // Segment is not full. - if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 { + if s.outstanding > 0 && s.ep.ops.GetDelayOption() { // Nagle's algorithm. From Wikipedia: // Nagle's algorithm works by // combining a number of small @@ -808,7 +842,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // send space and MSS. // TODO(gvisor.dev/issue/2833): Drain the held segments after a // timeout. - if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { + if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() { return false } } @@ -913,79 +947,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se return true } -// handleSACKRecovery implements the loss recovery phase as described in RFC6675 -// section 5, step C. -func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { - s.SetPipe() - - if smss := int(s.ep.scoreboard.SMSS()); limit > smss { - // Cap segment size limit to s.smss as SACK recovery requires - // that all retransmissions or new segments send during recovery - // be of <= SMSS. - limit = smss - } - - nextSegHint := s.writeList.Front() - for s.outstanding < s.sndCwnd { - var nextSeg *segment - var rescueRtx bool - nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint) - if nextSeg == nil { - return dataSent - } - if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) { - // New data being sent. - - // Step C.3 described below is handled by - // maybeSendSegment which increments sndNxt when - // a segment is transmitted. - // - // Step C.3 "If any of the data octets sent in - // (C.1) are above HighData, HighData must be - // updated to reflect the transmission of - // previously unsent data." - // - // We pass s.smss as the limit as the Step 2) requires that - // new data sent should be of size s.smss or less. - if sent := s.maybeSendSegment(nextSeg, limit, end); !sent { - return dataSent - } - dataSent = true - s.outstanding++ - s.writeNext = nextSeg.Next() - continue - } - - // Now handle the retransmission case where we matched either step 1,3 or 4 - // of the NextSeg algorithm. - // RFC 6675, Step C.4. - // - // "The estimate of the amount of data outstanding in the network - // must be updated by incrementing pipe by the number of octets - // transmitted in (C.1)." - s.outstanding++ - dataSent = true - s.sendSegment(nextSeg) - - segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) - if rescueRtx { - // We do the last part of rule (4) of NextSeg here to update - // RescueRxt as until this point we don't know if we are going - // to use the rescue transmission. - s.fr.rescueRxt = s.fr.last - } else { - // RFC 6675, Step C.2 - // - // "If any of the data octets sent in (C.1) are below - // HighData, HighRxt MUST be set to the highest sequence - // number of the retransmitted segment unless NextSeg () - // rule (4) was invoked for this retransmission." - s.fr.highRxt = segEnd - 1 - } - } - return dataSent -} - func (s *sender) sendZeroWindowProbe() { ack, win := s.ep.rcv.getSendParams() s.unackZeroWindowProbes++ @@ -1014,6 +975,30 @@ func (s *sender) disableZeroWindowProbing() { s.resendTimer.disable() } +func (s *sender) postXmit(dataSent bool) { + if dataSent { + // We sent data, so we should stop the keepalive timer to ensure + // that no keepalives are sent while there is pending data. + s.ep.disableKeepaliveTimer() + } + + // If the sender has advertized zero receive window and we have + // data to be sent out, start zero window probing to query the + // the remote for it's receive window size. + if s.writeNext != nil && s.sndWnd == 0 { + s.enableZeroWindowProbing() + } + + // Enable the timer if we have pending data and it's not enabled yet. + if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { + s.resendTimer.enable(s.rto) + } + // If we have no more pending data, start the keepalive timer. + if s.sndUna == s.sndNxt { + s.ep.resetKeepaliveTimer(false) + } +} + // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { @@ -1034,55 +1019,29 @@ func (s *sender) sendData() { } var dataSent bool - - // RFC 6675 recovery algorithm step C 1-5. - if s.fr.active && s.ep.sackPermitted { - dataSent = s.handleSACKRecovery(s.maxPayloadSize, end) - } else { - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize - if cwndLimit < limit { - limit = cwndLimit - } - if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - // Move writeNext along so that we don't try and scan data that - // has already been SACKED. - s.writeNext = seg.Next() - continue - } - if sent := s.maybeSendSegment(seg, limit, end); !sent { - break - } - dataSent = true - s.outstanding += s.pCount(seg) + for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + if cwndLimit < limit { + limit = cwndLimit + } + if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Move writeNext along so that we don't try and scan data that + // has already been SACKED. s.writeNext = seg.Next() + continue } + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding += s.pCount(seg, s.maxPayloadSize) + s.writeNext = seg.Next() } - if dataSent { - // We sent data, so we should stop the keepalive timer to ensure - // that no keepalives are sent while there is pending data. - s.ep.disableKeepaliveTimer() - } - - // If the sender has advertized zero receive window and we have - // data to be sent out, start zero window probing to query the - // the remote for it's receive window size. - if s.writeNext != nil && s.sndWnd == 0 { - s.enableZeroWindowProbing() - } - - // Enable the timer if we have pending data and it's not enabled yet. - if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { - s.resendTimer.enable(s.rto) - } - // If we have no more pending data, start the keepalive timer. - if s.sndUna == s.sndNxt { - s.ep.resetKeepaliveTimer(false) - } + s.postXmit(dataSent) } -func (s *sender) enterFastRecovery() { +func (s *sender) enterRecovery() { s.fr.active = true // Save state to reflect we're now in fast recovery. // @@ -1090,6 +1049,7 @@ func (s *sender) enterFastRecovery() { // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. s.sndCwnd = s.sndSsthresh + 3 + s.sackedOut = 0 s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding @@ -1104,7 +1064,7 @@ func (s *sender) enterFastRecovery() { s.ep.stack.Stats().TCP.FastRecovery.Increment() } -func (s *sender) leaveFastRecovery() { +func (s *sender) leaveRecovery() { s.fr.active = false s.fr.maxCwnd = 0 s.dupAckCount = 0 @@ -1115,57 +1075,6 @@ func (s *sender) leaveFastRecovery() { s.cc.PostRecovery() } -func (s *sender) handleFastRecovery(seg *segment) (rtx bool) { - ack := seg.ackNumber - // We are in fast recovery mode. Ignore the ack if it's out of - // range. - if !ack.InRange(s.sndUna, s.sndNxt+1) { - return false - } - - // Leave fast recovery if it acknowledges all the data covered by - // this fast recovery session. - if s.fr.last.LessThan(ack) { - s.leaveFastRecovery() - return false - } - - if s.ep.sackPermitted { - // When SACK is enabled we let retransmission be governed by - // the SACK logic. - return false - } - - // Don't count this as a duplicate if it is carrying data or - // updating the window. - if seg.logicalLen() != 0 || s.sndWnd != seg.window { - return false - } - - // Inflate the congestion window if we're getting duplicate acks - // for the packet we retransmitted. - if ack == s.fr.first { - // We received a dup, inflate the congestion window by 1 packet - // if we're not at the max yet. Only inflate the window if - // regular FastRecovery is in use, RFC6675 does not require - // inflating cwnd on duplicate ACKs. - if s.sndCwnd < s.fr.maxCwnd { - s.sndCwnd++ - } - return false - } - - // A partial ack was received. Retransmit this packet and - // remember it so that we don't retransmit it again. We don't - // inflate the window because we're putting the same packet back - // onto the wire. - // - // N.B. The retransmit timer will be reset by the caller. - s.fr.first = ack - s.dupAckCount = 0 - return true -} - // isAssignedSequenceNumber relies on the fact that we only set flags once a // sequencenumber is assigned and that is only done right before we send the // segment. As a result any segment that has a non-zero flag has a valid @@ -1228,14 +1137,11 @@ func (s *sender) SetPipe() { s.outstanding = pipe } -// checkDuplicateAck is called when an ack is received. It manages the state -// related to duplicate acks and determines if a retransmit is needed according -// to the rules in RFC 6582 (NewReno). -func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { +// detectLoss is called when an ack is received and returns whether a loss is +// detected. It manages the state related to duplicate acks and determines if +// a retransmit is needed according to the rules in RFC 6582 (NewReno). +func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { ack := seg.ackNumber - if s.fr.active { - return s.handleFastRecovery(seg) - } // We're not in fast recovery yet. A segment is considered a duplicate // only if it doesn't carry any data and doesn't update the send window, @@ -1266,14 +1172,14 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2 // // We only do the check here, the incrementing of last to the highest - // sequence number transmitted till now is done when enterFastRecovery + // sequence number transmitted till now is done when enterRecovery // is invoked. if !s.fr.last.LessThan(seg.ackNumber) { s.dupAckCount = 0 return false } s.cc.HandleNDupAcks() - s.enterFastRecovery() + s.enterRecovery() s.dupAckCount = 0 return true } @@ -1313,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg, s.ep.tsOffset) s.rc.detectReorder(seg) seg.acked = true + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } seg = seg.Next() } @@ -1415,14 +1322,23 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.SetPipe() } - // Count the duplicates and do the fast retransmit if needed. - rtx := s.checkDuplicateAck(rcvdSeg) + ack := rcvdSeg.ackNumber + fastRetransmit := false + // Do not leave fast recovery, if the ACK is out of range. + if s.fr.active { + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) { + s.leaveRecovery() + } + } else { + // Detect loss by counting the duplicates and enter recovery. + fastRetransmit = s.detectLoss(rcvdSeg) + } // Stash away the current window size. s.sndWnd = rcvdSeg.window - ack := rcvdSeg.ackNumber - // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously @@ -1477,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg) + prevCount := s.pCount(seg, s.maxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg) + s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) break } @@ -1496,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.writeList.Remove(seg) - // If SACK is enabled then Only reduce outstanding if + // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg) + s.outstanding -= s.pCount(seg, s.maxPayloadSize) + } else { + s.sackedOut -= s.pCount(seg, s.maxPayloadSize) } seg.decRef() ackLeft -= datalen @@ -1539,19 +1457,24 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.resendTimer.disable() } } + // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if rtx { - s.resendSegment() + if s.fr.active { + s.lr.DoRecovery(rcvdSeg, fastRetransmit) + // When SACK is enabled data sending is governed by steps in + // RFC 6675 Section 5 recovery steps A-C. + // See: https://tools.ietf.org/html/rfc6675#section-5. + if s.ep.sackPermitted { + return + } } // Send more data now that some of the pending data has been ack'd, or // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo { - s.sendData() - } + s.sendData() } // sendSegment sends the specified segment. diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index ef7f5719f..faf0c0ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) { expected++ } } + +// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK +// is received. +func TestSACKUpdateSackedOut(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + ackNum := 0 + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.SackedOut is what we expect. + if state.Sender.SackedOut != 2 && ackNum == 0 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) + } + + if state.Sender.SackedOut != 0 && ackNum == 1 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) + } + if ackNum > 0 { + close(probeDone) + } + ackNum++ + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + sendAndReceive(t, c, 8) + + // ACK for [3-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) + bytesRead := 2 * maxPayload + end := start.Add(seqnum.Size(bytesRead)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + bytesRead += 3 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9f0fb41e3..351a5e4f5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh - if err := ep.LastError(); err != tcpip.ErrAborted { - t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted) - } // Call Connect again to retreive the handshake failure status // and stats updates. @@ -267,7 +264,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) { } // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. - sent := stats.ICMP.V4PacketsSent + sent := stats.ICMP.V4.PacketsSent if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) } @@ -1935,6 +1932,84 @@ func TestFullWindowReceive(t *testing.T) { ) } +// Test the stack receive window advertisement on receiving segments smaller than +// segment overhead. It tests for the right edge of the window to not grow when +// the endpoint is not being read from. +func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize, + Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), + } + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + + c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Bump up the receive buffer size such that, when the receive window grows, + // the scaled window exceeds maxUint16. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) + } + + // Keep the payload size < segment overhead and such that it is a multiple + // of the window scaled value. This enables the test to perform equality + // checks on the incoming receive window. + payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale)) + payloadLen := seqnum.Size(len(payload)) + iss := seqnum.Value(789) + seqNum := iss.Add(1) + + // Send payload to the endpoint and return the advertised receive window + // from the endpoint. + getIncomingRcvWnd := func() uint32 { + c.SendPacket(payload, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + Flags: header.TCPFlagAck, + RcvWnd: 30000, + }) + seqNum = seqNum.Add(payloadLen) + + pkt := c.GetPacket() + return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale + } + + // Read the advertised receive window with the ACK for payload. + rcvWnd := getIncomingRcvWnd() + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Read the data so that the subsequent ACK from the endpoint + // grows the right edge of the window. + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("got Read(nil) = %s", err) + } + + // Check if we have received max uint16 as our advertised + // scaled window now after a read above. + maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) + if got, want := getIncomingRcvWnd(), maxRcv; got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } +} + func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -2532,10 +2607,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOptBool(tcpip.CorkOption, true) + ep.SocketOptions().SetCorkOption(true) }, func(ep tcpip.Endpoint) { - ep.SetSockOptBool(tcpip.CorkOption, false) + ep.SocketOptions().SetCorkOption(false) }, }, } @@ -2627,7 +2702,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptBool(tcpip.DelayOption, true) + c.EP.SocketOptions().SetDelayOption(true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -2675,7 +2750,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptBool(tcpip.DelayOption, true) + c.EP.SocketOptions().SetDelayOption(true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -2708,7 +2783,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptBool(tcpip.DelayOption, false) + c.EP.SocketOptions().SetDelayOption(false) // Check that data is received. second := c.GetPacket() @@ -2745,8 +2820,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, } for _, test := range tests { @@ -3198,6 +3273,11 @@ loop: case tcpip.ErrWouldBlock: select { case <-ch: + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + } + break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } @@ -3207,14 +3287,10 @@ loop: t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } } - // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) - } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) } @@ -4150,7 +4226,7 @@ func TestReadAfterClosedState(t *testing.T) { // Check that peek works. peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) + n, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { t.Fatalf("Peek failed: %s", err) } @@ -4176,7 +4252,7 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { + if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -4193,9 +4269,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4205,9 +4279,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4218,9 +4290,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4233,9 +4303,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4246,9 +4314,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4261,9 +4327,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4656,13 +4720,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) - } + ep.SocketOptions().SetV6Only(true) case "dual": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) - } + ep.SocketOptions().SetV6Only(false) default: t.Fatalf("unknown network: '%s'", network) } @@ -4998,9 +5058,7 @@ func TestKeepalive(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -6118,10 +6176,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Introduce a 25ms latency by delaying the first byte. latency := 25 * time.Millisecond time.Sleep(latency) - rawEP.SendPacketWithTS([]byte{1}, tsVal) + // Send an initial payload with atleast segment overhead size. The receive + // window would not grow for smaller segments. + rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal) pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() + time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. @@ -6394,10 +6455,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T if err != nil { t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) } - gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) - } + gotDelayOption := ep.SocketOptions().GetDelayOption() if gotDelayOption != wantDelayOption { t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) } @@ -7250,9 +7308,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // Set userTimeout to be the duration to be 1 keepalive // probes. Which means that after the first probe is sent diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e6aa4fc4b..ee55f030c 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -592,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } + c.EP.SocketOptions().SetV6Only(v6only) } // GetV6Packet reads a single packet from the link layer endpoint of the context @@ -637,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(header.TCPMinimumSize + len(payload)), + TransportProtocol: tcp.ProtocolNumber, + HopLimit: 65, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c78549424..153e8c950 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -56,6 +56,8 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 57976d4e3..763d1d654 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -16,8 +16,8 @@ package udp import ( "fmt" + "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -30,10 +30,11 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry - senderAddress tcpip.FullAddress - packetInfo tcpip.IPPacketInfo - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + senderAddress tcpip.FullAddress + destinationAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 // tos stores either the receiveTOS or receiveTClass value. tos uint8 } @@ -77,6 +78,7 @@ func (s EndpointState) String() string { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. @@ -94,21 +96,20 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` dstPort uint16 - v6only bool ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID - multicastLoop bool portFlags ports.Flags bindToDevice tcpip.NICID - noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -122,17 +123,6 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - // receiveTOS determines if the incoming IPv4 TOS header field is passed - // as ancillary data to ControlMessages on Read. - receiveTOS bool - - // receiveTClass determines if the incoming IPv6 TClass header field is - // passed as ancillary data to ControlMessages on Read. - receiveTClass bool - - // receiveIPPacketInfo determines if the packet info is returned by Read. - receiveIPPacketInfo bool - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -154,9 +144,6 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -188,13 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } + e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -210,6 +198,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue return e } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState() returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + // UniqueID implements stack.TransportEndpoint.UniqueID. func (e *endpoint) UniqueID() uint64 { return e.uniqueID @@ -235,7 +237,7 @@ func (e *endpoint) Close() { e.mu.Lock() e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { + switch e.EndpointState() { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) @@ -258,10 +260,13 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. - e.state = StateClosed + e.setEndpointState(StateClosed) e.mu.Unlock() @@ -303,24 +308,23 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess HasTimestamp: true, Timestamp: p.timestamp, } - e.mu.RLock() - receiveTOS := e.receiveTOS - receiveTClass := e.receiveTClass - receiveIPPacketInfo := e.receiveIPPacketInfo - e.mu.RUnlock() - if receiveTOS { + if e.ops.GetReceiveTOS() { cm.HasTOS = true cm.TOS = p.tos } - if receiveTClass { + if e.ops.GetReceiveTClass() { cm.HasTClass = true // Although TClass is an 8-bit value it's read in the CMsg as a uint32. cm.TClass = uint32(p.tos) } - if receiveIPPacketInfo { + if e.ops.GetReceivePacketInfo() { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } + if e.ops.GetReceiveOriginalDstAddress() { + cm.HasOriginalDstAddress = true + cm.OriginalDstAddress = p.destinationAddress + } return p.data.ToView(), cm, nil } @@ -330,7 +334,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // // Returns true for retry if preparation should be retried. func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { - switch e.state { + switch e.EndpointState() { case StateInitial: case StateConnected: return false, nil @@ -352,7 +356,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return true, nil } @@ -367,9 +371,9 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress - if isBroadcastOrMulticast(localAddr) { + if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" } @@ -384,9 +388,9 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) if err != nil { - return stack.Route{}, 0, err + return nil, 0, err } return r, nicID, nil } @@ -429,7 +433,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c to := opts.To e.mu.RLock() - defer e.mu.RUnlock() + lockReleased := false + defer func() { + if lockReleased { + return + } + e.mu.RUnlock() + }() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { @@ -448,36 +458,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } } - var route *stack.Route - var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) - var dstPort uint16 - if to == nil { - route = &e.route - dstPort = e.dstPort - resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { - // Promote lock to exclusive if using a shared route, given that it may - // need to change in Route.Resolve() call below. - e.mu.RUnlock() - e.mu.Lock() - - // Recheck state after lock was re-acquired. - if e.state != StateConnected { - err = tcpip.ErrInvalidEndpointState - } - if err == nil && route.IsResolutionRequired() { - ch, err = route.Resolve(waker) - } - - e.mu.Unlock() - e.mu.RLock() - - // Recheck state after lock was re-acquired. - if e.state != StateConnected { - err = tcpip.ErrInvalidEndpointState - } - return - } - } else { + route := e.route + dstPort := e.dstPort + if to != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC @@ -505,9 +488,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r dstPort = dst.Port - resolve = route.Resolve } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { @@ -515,7 +497,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if route.IsResolutionRequired() { - if ch, err := resolve(nil); err != nil { + if ch, err := route.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { return 0, ch, tcpip.ErrNoLinkAddress } @@ -541,77 +523,46 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { + localPort := e.ID.LocalPort + sendTOS := e.sendTOS + owner := e.owner + noChecksum := e.SocketOptions().GetNoChecksum() + lockReleased = true + e.mu.RUnlock() + + // Do not hold lock when sending as loopback is synchronous and if the UDP + // datagram ends up generating an ICMP response then it can result in a + // deadlock where the ICMP response handling ends up acquiring this endpoint's + // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a + // deadlock if another caller is trying to acquire e.mu in exclusive mode w/ + // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the + // lock can be eventually acquired. + // + // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read + // locking is prohibited. + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = v - e.mu.Unlock() - - case tcpip.NoChecksumOption: - e.mu.Lock() - e.noChecksum = v - e.mu.Unlock() - - case tcpip.ReceiveTOSOption: - e.mu.Lock() - e.receiveTOS = v - e.mu.Unlock() - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrNotSupported - } - - e.mu.Lock() - e.receiveTClass = v - e.mu.Unlock() - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.portFlags.MostRecent = v - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.portFlags.LoadBalanced = v - e.mu.Unlock() - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() +} - e.v6only = v - } - return nil +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.mu.Lock() + e.portFlags.LoadBalanced = v + e.mu.Unlock() } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. @@ -814,90 +765,10 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() } return nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption: - return false, nil - - case tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - return v, nil - - case tcpip.NoChecksumOption: - e.mu.RLock() - v := e.noChecksum - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTOSOption: - e.mu.RLock() - v := e.receiveTOS - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrNotSupported - } - - e.mu.RLock() - v := e.receiveTClass - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil - - case tcpip.ReuseAddressOption: - e.mu.RLock() - v := e.portFlags.MostRecent - e.mu.RUnlock() - - return v, nil - - case tcpip.ReusePortOption: - e.mu.RLock() - v := e.portFlags.LoadBalanced - e.mu.RUnlock() - - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.mu.RLock() - v := e.v6only - e.mu.RUnlock() - - return v, nil - - case tcpip.AcceptConnOption: - return false, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -972,11 +843,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() - case *tcpip.LingerOption: - e.mu.RLock() - *o = e.linger - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -1036,7 +902,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -1048,7 +914,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.state != StateConnected { + if e.EndpointState() != StateConnected { return nil } var ( @@ -1071,7 +937,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { if err != nil { return err } - e.state = StateBound + e.setEndpointState(StateBound) boundPortFlags = e.boundPortFlags } else { if e.ID.LocalPort != 0 { @@ -1079,14 +945,14 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - e.state = StateInitial + e.setEndpointState(StateInitial) } e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() - e.route = stack.Route{} + e.route = nil e.dstPort = 0 return nil @@ -1104,7 +970,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicID := addr.NIC var localPort uint16 - switch e.state { + switch e.EndpointState() { case StateInitial: case StateBound, StateConnected: localPort = e.ID.LocalPort @@ -1139,7 +1005,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { RemoteAddress: r.RemoteAddress, } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { id.LocalAddress = r.LocalAddress } @@ -1147,7 +1013,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv4ProtocolNumber, header.IPv6ProtocolNumber, @@ -1173,7 +1039,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.RegisterNICID = nicID e.effectiveNetProtos = netProtos - e.state = StateConnected + e.setEndpointState(StateConnected) e.rcvMu.Lock() e.rcvReady = true @@ -1195,7 +1061,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. - if e.state != StateBound && e.state != StateConnected { + if state := e.EndpointState(); state != StateBound && state != StateConnected { return tcpip.ErrNotConnected } @@ -1246,7 +1112,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1259,7 +1125,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // wildcard (empty) address, and this is an IPv6 endpoint with v6only // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv6ProtocolNumber, header.IPv4ProtocolNumber, @@ -1267,7 +1133,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicID := addr.NIC - if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { @@ -1290,7 +1156,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { e.effectiveNetProtos = netProtos // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) e.rcvMu.Lock() e.rcvReady = true @@ -1322,7 +1188,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() addr := e.ID.LocalAddress - if e.state == StateConnected { + if e.EndpointState() == StateConnected { addr = e.route.LocalAddress } @@ -1338,7 +1204,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != StateConnected { + if e.EndpointState() != StateConnected { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1393,7 +1259,6 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. @@ -1402,6 +1267,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } + // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap + // packets at "Parse" instead of when handling a packet. + pkt.Data.CapLength(int(hdr.PayloadLength())) + if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() @@ -1435,7 +1304,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), + }, + destinationAddress: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: header.UDP(hdr).DestinationPort(), }, } packet.data = pkt.Data @@ -1470,25 +1344,20 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { - e.mu.RLock() - if e.state == StateConnected { + if e.EndpointState() == StateConnected { e.lastErrorMu.Lock() e.lastError = tcpip.ErrConnectionRefused e.lastErrorMu.Unlock() - e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventErr) return } - e.mu.RUnlock() } } // State implements tcpip.Endpoint.State. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. @@ -1508,14 +1377,16 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements tcpip.Endpoint.Wait. func (*endpoint) Wait() {} -func isBroadcastOrMulticast(a tcpip.Address) bool { - return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 858c99a45..13b72dc88 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) { } } - if e.state != StateBound && e.state != StateConnected { + state := e.EndpointState() + if state != StateBound && state != StateConnected { return } @@ -113,12 +114,12 @@ func (e *endpoint) Resume(s *stack.Stack) { } var err *tcpip.Error - if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) + if state == StateConnected { + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { panic(err) } - } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 764ad0857..08980c298 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -32,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,6 +56,7 @@ const ( stackPort = 1234 testAddr = "\x0a\x00\x00\x02" testPort = 4096 + invalidPort = 8192 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast @@ -295,7 +298,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, }) } @@ -360,9 +364,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetV6Only(true) } else if flow.isBroadcast() { c.ep.SocketOptions().SetBroadcast(true) } @@ -451,12 +453,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -972,7 +974,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { // provided. func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, true, checkers...) + return testWriteAndVerifyInternal(c, flow, true, checkers...) } // testWriteWithoutDestination sends a packet of the given test flow from the @@ -981,10 +983,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker // checker functions provided. func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, false, checkers...) + return testWriteAndVerifyInternal(c, flow, false, checkers...) } -func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { +func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() @@ -1006,6 +1008,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } c.checkEndpointWriteStats(1, epstats, err) + return payload +} + +func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -1150,6 +1158,39 @@ func TestV4WriteOnConnected(t *testing.T) { testWriteWithoutDestination(c, unicastV4) } +func TestWriteOnConnectedInvalidPort(t *testing.T) { + protocols := map[string]tcpip.NetworkProtocolNumber{ + "ipv4": ipv4.ProtocolNumber, + "ipv6": ipv6.ProtocolNumber, + } + for name, pn := range protocols { + t.Run(name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(pn) + if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { + c.t.Fatalf("Connect failed: %s", err) + } + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, + } + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + if err != nil { + c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + } + if got, want := n, int64(len(payload)); got != want { + c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) + } + + if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } + }) + } +} + // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket // that is bound to a V4 multicast address. func TestWriteOnBoundToV4Multicast(t *testing.T) { @@ -1372,9 +1413,7 @@ func TestReadIPPacketInfo(t *testing.T) { } } - if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { - t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) - } + c.ep.SocketOptions().SetReceivePacketInfo(true) testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ NIC: 1, @@ -1389,6 +1428,93 @@ func TestReadIPPacketInfo(t *testing.T) { } } +func TestReadRecvOriginalDstAddr(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedOriginalDstAddr tcpip.FullAddress + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%#v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) + } + } + + c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) + + testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1412,16 +1538,12 @@ func TestNoChecksum(t *testing.T) { c.createEndpointForFlow(flow) // Disable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(true) // This option is effective on IPv4 only. testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) // Enable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(false) testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) }) } @@ -1591,13 +1713,15 @@ func TestSetTClass(t *testing.T) { } func TestReceiveTosTClass(t *testing.T) { + const RcvTOSOpt = "ReceiveTosOption" + const RcvTClassOpt = "ReceiveTClassOption" + testCases := []struct { - name string - getReceiveOption tcpip.SockOptBool - tests []testFlow + name string + tests []testFlow }{ - {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, - {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, + {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, } for _, testCase := range testCases { for _, flow := range testCase.tests { @@ -1606,29 +1730,32 @@ func TestReceiveTosTClass(t *testing.T) { defer c.cleanup() c.createEndpointForFlow(flow) - option := testCase.getReceiveOption name := testCase.name - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) + var optionGetter func() bool + var optionSetter func(bool) + switch name { + case RcvTOSOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTOS + optionSetter = c.ep.SocketOptions().SetReceiveTOS + case RcvTClassOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTClass + optionSetter = c.ep.SocketOptions().SetReceiveTClass + default: + t.Fatalf("unkown test variant: %s", name) } + + // Verify that setting and reading the option works. + v := optionGetter() // Test for expected default value. if v != false { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } want := true - if err := c.ep.SetSockOptBool(option, want); err != nil { - c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) - } - - got, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) - } + optionSetter(want) + got := optionGetter() if got != want { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) } @@ -1638,10 +1765,10 @@ func TestReceiveTosTClass(t *testing.T) { if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %s", err) } - switch option { - case tcpip.ReceiveTClassOption: + switch name { + case RcvTClassOpt: testRead(c, flow, checker.ReceiveTClass(testTOS)) - case tcpip.ReceiveTOSOption: + case RcvTOSOpt: testRead(c, flow, checker.ReceiveTOS(testTOS)) default: t.Fatalf("unknown test variant: %s", name) @@ -1788,27 +1915,31 @@ func TestV4UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantLen := len(payload) + wantPayloadLen := len(payload) if tc.largePayload { // To work out the data size we need to simulate what the sender would // have done. The wanted size is the total available minus the sum of // the headers in the UDP AND ICMP packets, given that we know the test // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength } // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - + payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + wantDgramLen := wantPayloadLen + header.UDPMinimumSize + + if got, want := len(origDgram), wantDgramLen; got != want { + t.Fatalf("got len(origDgram) = %d, want = %d", got, want) } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) + // Correct UDP length to access payload. + origDgram.SetLength(uint16(wantDgramLen)) + + if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) { + t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want) } }) } @@ -1883,20 +2014,23 @@ func TestV6UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv6(hdr.Payload()) payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) + wantPayloadLen := len(payload) if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize } + wantDgramLen := wantPayloadLen + header.UDPMinimumSize // In case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + payloadIPHeader.SetPayloadLength(uint16(wantDgramLen)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want { + t.Fatalf("got len(origDgram) = %d, want = %d", got, want) } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) + // Correct UDP length to access payload. + origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize)) + if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" { + t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff) } }) } @@ -1955,12 +2089,12 @@ func TestShortHeader(t *testing.T) { // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -2409,3 +2543,67 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { }) } } + +func TestReceiveShortLength(t *testing.T) { + flows := []testFlow{unicastV4, unicastV6} + for _, flow := range flows { + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err) + } + + payload := newPayload() + extraBytes := []byte{1, 2, 3, 4} + h := flow.header4Tuple(incoming) + var buf buffer.View + var proto tcpip.NetworkProtocolNumber + + // Build packets with extra bytes not accounted for in the UDP length + // field. + var udp header.UDP + if flow.isV4() { + buf = c.buildV4Packet(payload, &h) + buf = append(buf, extraBytes...) + ip := header.IPv4(buf) + ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes))) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + proto = ipv4.ProtocolNumber + udp = ip.Payload() + } else { + buf = c.buildV6Packet(payload, &h) + buf = append(buf, extraBytes...) + ip := header.IPv6(buf) + ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes))) + proto = ipv6.ProtocolNumber + udp = ip.Payload() + } + + if diff := cmp.Diff(payload, udp.Payload()); diff != "" { + t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff) + } + + c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Try to receive the data. + v, _, err := c.ep.Read(nil) + if err != nil { + t.Fatalf("c.ep.Read(nil): %s", err) + } + + // Check the payload is read back without extra bytes. + if diff := cmp.Diff(buffer.View(payload), v); diff != "" { + t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff) + } + }) + } +} |