From 313c767b0001bf6271405f1b765b60a334d6e911 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 27 Aug 2019 18:53:34 -0700 Subject: Populate link address cache at dispatch This allows the stack to learn remote link addresses on incoming packets, reducing the need to ARP to send responses. This also reduces the number of round trips to the system clock, since that may also prove to be performance-sensitive. Fixes #739. PiperOrigin-RevId: 265815816 --- pkg/tcpip/header/ipv6.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'pkg/tcpip/header') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 31be42ce0..bc4e56535 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -27,7 +27,7 @@ const ( nextHdr = 6 hopLimit = 7 v6SrcAddr = 8 - v6DstAddr = 24 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -119,13 +119,13 @@ func (b IPv6) Payload() []byte { // SourceAddress returns the "source address" field of the ipv6 header. func (b IPv6) SourceAddress() tcpip.Address { - return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]) + return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize]) } // DestinationAddress returns the "destination address" field of the ipv6 // header. func (b IPv6) DestinationAddress() tcpip.Address { - return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize]) + return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize]) } // Checksum implements Network.Checksum. Given that IPv6 doesn't have a @@ -153,13 +153,13 @@ func (b IPv6) SetPayloadLength(payloadLength uint16) { // SetSourceAddress sets the "source address" field of the ipv6 header. func (b IPv6) SetSourceAddress(addr tcpip.Address) { - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr) + copy(b[v6SrcAddr:][:IPv6AddressSize], addr) } // SetDestinationAddress sets the "destination address" field of the ipv6 // header. func (b IPv6) SetDestinationAddress(addr tcpip.Address) { - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr) + copy(b[v6DstAddr:][:IPv6AddressSize], addr) } // SetNextHeader sets the value of the "next header" field of the ipv6 header. @@ -178,8 +178,8 @@ func (b IPv6) Encode(i *IPv6Fields) { b.SetPayloadLength(i.PayloadLength) b[nextHdr] = i.NextHeader b[hopLimit] = i.HopLimit - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr) - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr) + b.SetSourceAddress(i.SrcAddr) + b.SetDestinationAddress(i.DstAddr) } // IsValid performs basic validation on the packet. -- cgit v1.2.3 From 3789c34b22e7a7466149bfbeedf05bf49188130c Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 3 Sep 2019 15:59:58 -0700 Subject: Make UDP traceroute work. Adds support to generate Port Unreachable messages for UDP datagrams received on a port for which there is no valid endpoint. Fixes #703 PiperOrigin-RevId: 267034418 --- pkg/tcpip/checker/checker.go | 100 +++++++++++++ pkg/tcpip/header/icmpv4.go | 71 +++++++++- pkg/tcpip/header/icmpv6.go | 88 +++++++++++- pkg/tcpip/header/ipv4.go | 33 +++-- pkg/tcpip/header/ipv6.go | 18 +-- pkg/tcpip/network/ip_test.go | 12 +- pkg/tcpip/network/ipv4/icmp.go | 4 +- pkg/tcpip/network/ipv6/icmp.go | 31 +--- pkg/tcpip/network/ipv6/icmp_test.go | 4 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/icmp_rate_limit.go | 86 ++++++++++++ pkg/tcpip/stack/nic.go | 12 +- pkg/tcpip/stack/registration.go | 2 +- pkg/tcpip/stack/route.go | 5 + pkg/tcpip/stack/stack.go | 36 +++++ pkg/tcpip/stack/transport_test.go | 2 +- pkg/tcpip/tcpip.go | 8 ++ pkg/tcpip/transport/icmp/endpoint.go | 19 ++- pkg/tcpip/transport/icmp/protocol.go | 11 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 8 +- pkg/tcpip/transport/udp/protocol.go | 101 ++++++++++++- pkg/tcpip/transport/udp/udp_test.go | 156 ++++++++++++++++++++- 23 files changed, 717 insertions(+), 94 deletions(-) create mode 100644 pkg/tcpip/stack/icmp_rate_limit.go (limited to 'pkg/tcpip/header') diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index afcabd51d..096ad71ab 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -586,3 +586,103 @@ func Payload(want []byte) TransportChecker { } } } + +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and +// potentially additional ICMPv4 header fields. +func ICMPv4(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) + } + + icmp := header.ICMPv4(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv4Type creates a checker that checks the ICMPv4 Type field. +func ICMPv4Type(want header.ICMPv4Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv4Code creates a checker that checks the ICMPv4 Code field. +func ICMPv4Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} + +// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and +// potentially additional ICMPv6 header fields. +func ICMPv6(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) + } + + icmp := header.ICMPv6(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv6Type creates a checker that checks the ICMPv6 Type field. +func ICMPv6Type(want header.ICMPv6Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv6Code creates a checker that checks the ICMPv6 Code field. +func ICMPv6Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c52c0d851..0cac6c0a5 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -25,13 +26,29 @@ type ICMPv4 []byte const ( // ICMPv4PayloadOffset defines the start of ICMP payload. - ICMPv4PayloadOffset = 4 + ICMPv4PayloadOffset = 8 // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. ICMPv4MinimumSize = 8 // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 + + // icmpv4ChecksumOffset is the offset of the checksum field + // in an ICMPv4 message. + icmpv4ChecksumOffset = 2 + + // icmpv4MTUOffset is the offset of the MTU field + // in a ICMPv4FragmentationNeeded message. + icmpv4MTUOffset = 6 + + // icmpv4IdentOffset is the offset of the ident field + // in a ICMPv4EchoRequest/Reply message. + icmpv4IdentOffset = 4 + + // icmpv4SequenceOffset is the offset of the sequence field + // in a ICMPv4EchoRequest/Reply message. + icmpv4SequenceOffset = 6 ) // ICMPv4Type is the ICMP type field described in RFC 792. @@ -72,12 +89,12 @@ func (b ICMPv4) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) } // SetChecksum sets the ICMP checksum field. func (b ICMPv4) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -102,3 +119,51 @@ func (ICMPv4) SetDestinationPort(uint16) { func (b ICMPv4) Payload() []byte { return b[ICMPv4PayloadOffset:] } + +// MTU retrieves the MTU field from an ICMPv4 message. +func (b ICMPv4) MTU() uint16 { + return binary.BigEndian.Uint16(b[icmpv4MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv4 message. +func (b ICMPv4) SetMTU(mtu uint16) { + binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv4 message. +func (b ICMPv4) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv4IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv4 message. +func (b ICMPv4) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv4 message. +func (b ICMPv4) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv4 message. +func (b ICMPv4) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence) +} + +// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, +// and payload. +func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := uint16(0) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 3cc57e234..1125a7d14 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -25,14 +26,18 @@ type ICMPv6 []byte const ( // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. - ICMPv6MinimumSize = 4 + ICMPv6MinimumSize = 8 + + // ICMPv6PayloadOffset is the offset of the payload in an + // ICMP packet. + ICMPv6PayloadOffset = 8 // ICMPv6ProtocolNumber is the ICMP transport protocol number. ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58 // ICMPv6NeighborSolicitMinimumSize is the minimum size of a // neighbor solicitation packet. - ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16 + ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16 // ICMPv6NeighborAdvertSize is size of a neighbor advertisement. ICMPv6NeighborAdvertSize = 32 @@ -42,11 +47,27 @@ const ( // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP // destination unreachable packet. - ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP // packet-too-big packet. - ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + + // icmpv6ChecksumOffset is the offset of the checksum field + // in an ICMPv6 message. + icmpv6ChecksumOffset = 2 + + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 + // PacketTooBig message. + icmpv6MTUOffset = 4 + + // icmpv6IdentOffset is the offset of the ident field + // in a ICMPv6 Echo Request/Reply message. + icmpv6IdentOffset = 4 + + // icmpv6SequenceOffset is the offset of the sequence field + // in a ICMPv6 Echo Request/Reply message. + icmpv6SequenceOffset = 6 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -89,12 +110,12 @@ func (b ICMPv6) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } // SetChecksum calculates and sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -115,7 +136,60 @@ func (ICMPv6) SetSourcePort(uint16) { func (ICMPv6) SetDestinationPort(uint16) { } +// MTU retrieves the MTU field from an ICMPv6 message. +func (b ICMPv6) MTU() uint32 { + return binary.BigEndian.Uint32(b[icmpv6MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv6 message. +func (b ICMPv6) SetMTU(mtu uint32) { + binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv6 message. +func (b ICMPv6) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv6IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv6 message. +func (b ICMPv6) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv6 message. +func (b ICMPv6) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv6 message. +func (b ICMPv6) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) +} + // Payload implements Transport.Payload. func (b ICMPv6) Payload() []byte { - return b[ICMPv6MinimumSize:] + return b[ICMPv6PayloadOffset:] +} + +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// IPv6 src/dst addresses and the payload. +func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := Checksum([]byte(src), 0) + xsum = Checksum([]byte(dst), xsum) + var upperLayerLength [4]byte + binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) + xsum = Checksum(upperLayerLength[:], xsum) + xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 17fc9c68e..554632a64 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -21,16 +21,18 @@ import ( ) const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksum = 10 - srcAddr = 12 - dstAddr = 16 + versIHL = 0 + tos = 1 + // IPv4TotalLenOffset is the offset of the total length field in the + // IPv4 header. + IPv4TotalLenOffset = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksum = 10 + srcAddr = 12 + dstAddr = 16 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -103,6 +105,11 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP + // packet that every IPv4 capable host must be able to + // process/reassemble. + IPv4MinimumProcessableDatagramSize = 576 ) // Flags that may be set in an IPv4 packet. @@ -163,7 +170,7 @@ func (b IPv4) FragmentOffset() uint16 { // TotalLength returns the "total length" field of the ipv4 header. func (b IPv4) TotalLength() uint16 { - return binary.BigEndian.Uint16(b[totalLen:]) + return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } // Checksum returns the checksum field of the ipv4 header. @@ -209,7 +216,7 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { - binary.BigEndian.PutUint16(b[totalLen:], totalLength) + binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } // SetChecksum sets the checksum field of the ipv4 header. @@ -265,7 +272,7 @@ func (b IPv4) Encode(i *IPv4Fields) { // packets are produced. func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) { b.SetTotalLength(totalLength) - checksum := Checksum(b[totalLen:totalLen+2], partialChecksum) + checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum) b.SetChecksum(^checksum) } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index bc4e56535..093850e25 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -22,12 +22,14 @@ import ( ) const ( - versTCFL = 0 - payloadLen = 4 - nextHdr = 6 - hopLimit = 7 - v6SrcAddr = 8 - v6DstAddr = v6SrcAddr + IPv6AddressSize + versTCFL = 0 + // IPv6PayloadLenOffset is the offset of the PayloadLength field in + // IPv6 header. + IPv6PayloadLenOffset = 4 + nextHdr = 6 + hopLimit = 7 + v6SrcAddr = 8 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -94,7 +96,7 @@ var IPv6EmptySubnet = func() tcpip.Subnet { // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { - return binary.BigEndian.Uint16(b[payloadLen:]) + return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) } // HopLimit returns the value of the "hop limit" field of the ipv6 header. @@ -148,7 +150,7 @@ func (b IPv6) SetTOS(t uint8, l uint32) { // SetPayloadLength sets the "payload length" field of the ipv6 header. func (b IPv6) SetPayloadLength(payloadLength uint16) { - binary.BigEndian.PutUint16(b[payloadLen:], payloadLength) + binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength) } // SetSourceAddress sets the "source address" field of the ipv6 header. diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6bbfcd97f..4b3bd74fa 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -319,7 +319,8 @@ func TestIPv4ReceiveControl(t *testing.T) { icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) icmp.SetType(header.ICMPv4DstUnreachable) icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv4 header. ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) @@ -539,7 +540,7 @@ func TestIPv6ReceiveControl(t *testing.T) { defer ep.Close() - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4 + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize } @@ -559,10 +560,11 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) icmp.SetType(c.typ) icmp.SetCode(c.code) - copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) ip.Encode(&header.IPv6Fields{ PayloadLength: 100, NextHeader: 10, @@ -574,7 +576,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Build the fragmentation header if needed. if c.fragmentOffset != nil { ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) frag.Encode(&header.IPv6FragmentFields{ NextHeader: 10, FragmentOffset: *c.fragmentOffset, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 497164cbb..a25756443 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,8 +15,6 @@ package ipv4 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -117,7 +115,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.handleControl(stack.ControlPortUnreachable, 0, vv) case header.ICMPv4FragmentationNeeded: - mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:])) + mtu := uint32(h.MTU()) e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1689af16f..346de9ae3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,8 +15,6 @@ package ipv6 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -82,7 +80,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V return } vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:]) + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) case header.ICMPv6DstUnreachable: @@ -130,7 +128,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { sent.Dropped.Increment() @@ -162,7 +160,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { sent.Dropped.Increment() return @@ -233,7 +231,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr pkt[icmpV6LengthOffset] = 1 copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) length := uint16(hdr.UsedLength()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -272,24 +270,3 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return "", false } - -func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := header.Checksum([]byte(src), 0) - xsum = header.Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = header.Checksum(upperLayerLength[:], xsum) - xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = header.Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^header.Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum -} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d0dc72506..227a65cf2 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -153,7 +153,7 @@ func TestICMPCounts(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) handleIPv6Payload(hdr) } @@ -321,7 +321,7 @@ func TestLinkResolution(t *testing.T) { hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index b692c60ce..788de3dfe 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", "nic.go", @@ -42,6 +43,7 @@ go_library( "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go new file mode 100644 index 000000000..f8156be47 --- /dev/null +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -0,0 +1,86 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "sync" + + "golang.org/x/time/rate" +) + +const ( + // icmpLimit is the default maximum number of ICMP messages permitted by this + // rate limiter. + icmpLimit = 1000 + + // icmpBurst is the default number of ICMP messages that can be sent in a single + // burst. + icmpBurst = 50 +) + +// ICMPRateLimiter is a global rate limiter that controls the generation of +// ICMP messages generated by the stack. +type ICMPRateLimiter struct { + mu sync.RWMutex + l *rate.Limiter +} + +// NewICMPRateLimiter returns a global rate limiter for controlling the rate +// at which ICMP messages are generated by the stack. +func NewICMPRateLimiter() *ICMPRateLimiter { + return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)} +} + +// Allow returns true if we are allowed to send at least 1 message at the +// moment. +func (i *ICMPRateLimiter) Allow() bool { + i.mu.RLock() + allow := i.l.Allow() + i.mu.RUnlock() + return allow +} + +// Limit returns the maximum number of ICMP messages that can be sent in one +// second. +func (i *ICMPRateLimiter) Limit() rate.Limit { + i.mu.RLock() + defer i.mu.RUnlock() + return i.l.Limit() +} + +// SetLimit sets the maximum number of ICMP messages that can be sent in one +// second. +func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) { + i.mu.RLock() + defer i.mu.RUnlock() + i.l.SetLimit(newLimit) +} + +// Burst returns how many ICMP messages can be sent at any single instant. +func (i *ICMPRateLimiter) Burst() int { + i.mu.RLock() + defer i.mu.RUnlock() + return i.l.Burst() +} + +// SetBurst sets the maximum number of ICMP messages allowed at any single +// instant. +// +// NOTE: Changing Burst causes the underlying rate limiter to be recreated. +func (i *ICMPRateLimiter) SetBurst(burst int) { + i.mu.Lock() + i.l = rate.NewLimiter(i.l.Limit(), burst) + i.mu.Unlock() +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f947b55db..ae56e0ffd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -679,7 +679,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { n.stack.stats.MalformedRcvdPackets.Increment() } } @@ -720,6 +720,11 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Stack returns the instance of the Stack that owns this NIC. +func (n *NIC) Stack() *Stack { + return n.stack +} + type networkEndpointKind int32 const ( @@ -823,3 +828,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool { } } } + +// stack returns the Stack instance that owns the underlying endpoint. +func (r *referencedNetworkEndpoint) stack() *Stack { + return r.nic.stack +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2037eef9f..67b70b2ee 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -109,7 +109,7 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index e52cdd674..5c8b7977a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -217,3 +217,8 @@ func (r *Route) MakeLoopedRoute() Route { } return l } + +// Stack returns the instance of the Stack that owns this route. +func (r *Route) Stack() *Stack { + return r.ref.stack() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d69162ba1..1d5e84a8b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -28,6 +28,7 @@ import ( "sync" "time" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -389,6 +390,10 @@ type Stack struct { // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. resumableEndpoints []ResumableEndpoint + + // icmpRateLimiter is a global rate limiter for all ICMP messages generated + // by the stack. + icmpRateLimiter *ICMPRateLimiter } // Options contains optional Stack configuration. @@ -434,6 +439,7 @@ func New(network []string, transport []string, opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, raw: opts.Raw, + icmpRateLimiter: NewICMPRateLimiter(), } // Add specified network protocols. @@ -1215,3 +1221,33 @@ func (s *Stack) IPTables() iptables.IPTables { func (s *Stack) SetIPTables(ipt iptables.IPTables) { s.tables = ipt } + +// ICMPLimit returns the maximum number of ICMP messages that can be sent +// in one second. +func (s *Stack) ICMPLimit() rate.Limit { + return s.icmpRateLimiter.Limit() +} + +// SetICMPLimit sets the maximum number of ICMP messages that be sent +// in one second. +func (s *Stack) SetICMPLimit(newLimit rate.Limit) { + s.icmpRateLimiter.SetLimit(newLimit) +} + +// ICMPBurst returns the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) ICMPBurst() int { + return s.icmpRateLimiter.Burst() +} + +// SetICMPBurst sets the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) SetICMPBurst(burst int) { + s.icmpRateLimiter.SetBurst(burst) +} + +// AllowICMPMessage returns true if we the rate limiter allows at least one +// ICMP message to be sent at this instant. +func (s *Stack) AllowICMPMessage() bool { + return s.icmpRateLimiter.Allow() +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5335897f5..ca185279e 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -251,7 +251,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 8f9b86cce..05aa42c98 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -720,6 +720,10 @@ type ICMPv4SentPacketStats struct { // Dropped is the total number of ICMPv4 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. @@ -738,6 +742,10 @@ type ICMPv6SentPacketStats struct { // Dropped is the total number of ICMPv6 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 451d3880e..e1f622af6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,7 +15,6 @@ package icmp import ( - "encoding/binary" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -368,14 +367,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // Set the ident to the user-specified port. Sequence number should - // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + icmpv4.SetIdent(ident) data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. @@ -394,14 +392,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident) - - hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) copy(icmpv6, data) - data = data[header.ICMPv6EchoMinimumSize:] + // Set the ident. Sequence number is provided by the user. + icmpv6.SetIdent(ident) + data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return tcpip.ErrInvalidEndpointState diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 7fdba5d56..1eb790932 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -23,7 +23,6 @@ package icmp import ( - "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -92,7 +91,7 @@ func (p *protocol) MinimumPacketSize() int { case ProtocolNumber4: return header.ICMPv4MinimumSize case ProtocolNumber6: - return header.ICMPv6EchoMinimumSize + return header.ICMPv6MinimumSize } panic(fmt.Sprint("unknown protocol number: ", p.number)) } @@ -101,16 +100,18 @@ func (p *protocol) MinimumPacketSize() int { func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { switch p.number { case ProtocolNumber4: - return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil + hdr := header.ICMPv4(v) + return 0, hdr.Ident(), nil case ProtocolNumber6: - return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil + hdr := header.ICMPv6(v) + return 0, hdr.Ident(), nil } panic(fmt.Sprint("unknown protocol number: ", p.number)) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index ee04dcfcc..2a13b2022 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -129,7 +129,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 272481aa0..18c707a57 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -267,7 +267,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -286,9 +286,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) icmp.SetType(typ) icmp.SetCode(code) - - copy(icmp[header.ICMPv4PayloadOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) + const icmpv4VariableHeaderOffset = 4 + copy(icmp[icmpv4VariableHeaderOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f76e7fbe1..068d9a272 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -69,7 +69,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + // TODO(b/129426613): only send an ICMP message if UDP checksum is valid. + + // Only send ICMP error if the address is not a multicast/broadcast + // v4/v6 address or the source is not the unspecified address. + // + // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { + return true + } + + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported; or + // + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + switch len(id.LocalAddress) { + case header.IPv4AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() + return true + } + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 where it mentioned that at least 8 + // bytes of the payload must be included. Today linux and other + // systems implement the] RFC1812 definition and not the original + // RFC 1122 requirement. + mtu := int(r.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4DstUnreachable) + pkt.SetCode(header.ICMPv4PortUnreachable) + pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + + case header.IPv6AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() + return true + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) + pkt.SetType(header.ICMPv6DstUnreachable) + pkt.SetCode(header.ICMPv6PortUnreachable) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 9da6edce2..995d6e8a1 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -461,7 +461,11 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { } func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) + return newMinPayload(30) +} + +func newMinPayload(minSize int) []byte { + b := make([]byte, minSize+rand.Intn(100)) for i := range b { b[i] = byte(rand.Intn(256)) } @@ -1238,3 +1242,153 @@ func TestMulticastInterfaceOption(t *testing.T) { }) } } + +// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV4UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true, will result in a payload large enough + // so that the final generated IPv4 packet is larger than + // header.IPv4MinimumProcessableDatagramSize. + largePayload bool + }{ + {unicastV4, true, false}, + {unicastV4, true, true}, + {multicastV4, false, false}, + {multicastV4, false, true}, + {broadcast, false, false}, + {broadcast, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(576) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } + + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} + +// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV6UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true will result in a payload large enough to + // create an IPv6 packet > header.IPv6MinimumMTU bytes. + largePayload bool + }{ + {unicastV6, true, false}, + {unicastV6, true, true}, + {multicastV6, false, false}, + {multicastV6, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(1280) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} -- cgit v1.2.3 From a8943325db43d04be8a10157f6d3f3180e5170a5 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 9 Sep 2019 12:04:33 -0700 Subject: Join IPv6 all-nodes and solicited-node multicast addresses where appropriate. The IPv6 all-nodes multicast address will be joined on NIC enable, and the appropriate IPv6 solicited-node multicast address will be joined when IPv6 addresses are added. Tests: Test receiving packets destined to the IPv6 link-local all-nodes multicast address and the IPv6 solicted node address of an added IPv6 address. PiperOrigin-RevId: 268047073 --- pkg/tcpip/header/ipv6.go | 7 ++ pkg/tcpip/network/ipv6/BUILD | 2 + pkg/tcpip/network/ipv6/icmp_test.go | 6 - pkg/tcpip/network/ipv6/ipv6_test.go | 215 ++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/nic.go | 69 +++++++++++- pkg/tcpip/stack/stack.go | 6 +- 6 files changed, 290 insertions(+), 15 deletions(-) create mode 100644 pkg/tcpip/network/ipv6/ipv6_test.go (limited to 'pkg/tcpip/header') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 093850e25..e606e3463 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -76,6 +76,13 @@ const ( // IPv6Version is the version of the ipv6 protocol. IPv6Version = 6 + // IPv6AllNodesMulticastAddress is a link-local multicast group that + // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all nodes on a link. + // + // The address is ff02::1. + IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index c71b69123..d02ca0227 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -25,6 +25,7 @@ go_test( size = "small", srcs = [ "icmp_test.go", + "ipv6_test.go", "ndp_test.go", ], embed = [":ipv6"], @@ -36,6 +37,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a6a1a5232..653d984e9 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -222,9 +222,6 @@ func newTestContext(t *testing.T) *testContext { if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) @@ -234,9 +231,6 @@ func newTestContext(t *testing.T) *testContext { if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..b07e99dd4 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,215 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveICMP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as we haven't added + // those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited + // node address of addr2/addr3 now that we have added + // added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(1, addr2); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + if err := s.RemoveAddress(1, addr3); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as both of them got + // removed. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43719085e..249a19946 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -102,6 +102,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -339,6 +358,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 address, join the solicited-node multicast + // address for a unicast protocolAddress. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && !header.IsV6MulticastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -467,13 +495,27 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || r.getKind() != permanent { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } r.setKind(permanentExpired) - r.decRefLocked() + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } + + // At this point the endpoint is deleted. + + // If we are removing an IPv6 address, leave the solicited-node + // multicast address for a unicast addr. + if r.protocol == header.IPv6ProtocolNumber && !header.IsV6MulticastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -491,6 +533,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -518,6 +567,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -802,11 +858,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a961e8ebe..1fe21b68e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -633,7 +633,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil @@ -680,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. -- cgit v1.2.3 From 857940d30d3a8dbb099bad43954fe8062b70461d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 12 Sep 2019 13:50:58 -0700 Subject: Automated rollback of changelist 268047073 PiperOrigin-RevId: 268757842 --- pkg/tcpip/header/ipv6.go | 7 -- pkg/tcpip/network/ipv6/BUILD | 2 - pkg/tcpip/network/ipv6/icmp_test.go | 6 + pkg/tcpip/network/ipv6/ipv6_test.go | 215 ------------------------------------ pkg/tcpip/stack/nic.go | 69 +----------- pkg/tcpip/stack/stack.go | 6 +- 6 files changed, 15 insertions(+), 290 deletions(-) delete mode 100644 pkg/tcpip/network/ipv6/ipv6_test.go (limited to 'pkg/tcpip/header') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index e606e3463..093850e25 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -76,13 +76,6 @@ const ( // IPv6Version is the version of the ipv6 protocol. IPv6Version = 6 - // IPv6AllNodesMulticastAddress is a link-local multicast group that - // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets - // destined to this address will reach all nodes on a link. - // - // The address is ff02::1. - IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index d02ca0227..c71b69123 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -25,7 +25,6 @@ go_test( size = "small", srcs = [ "icmp_test.go", - "ipv6_test.go", "ndp_test.go", ], embed = [":ipv6"], @@ -37,7 +36,6 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 653d984e9..a6a1a5232 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -222,6 +222,9 @@ func newTestContext(t *testing.T) *testContext { if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } + if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { + t.Fatalf("AddAddress sn lladdr0: %v", err) + } c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) @@ -231,6 +234,9 @@ func newTestContext(t *testing.T) *testContext { if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } + if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { + t.Fatalf("AddAddress sn lladdr1: %v", err) + } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go deleted file mode 100644 index b07e99dd4..000000000 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // The least significant 3 bytes are the same as addr2 so both addr2 and - // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" -) - -// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the -// expected Neighbor Advertisement received count after receiving the packet. -func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) - - stats := s.Stats().ICMP.V6PacketsReceived - - if got := stats.NeighborAdvert.Value(); got != want { - t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) - } -} - -// testReceiveICMP tests receiving a UDP packet from src to dst. want is the -// expected UDP received count after receiving the packet. -func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - - // Receive UDP Packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - - // UDP pseudo-header checksum. - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) - - // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) - - stat := s.Stats().UDP.PacketsReceived - - if got := stat.Value(); got != want { - t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) - } -} - -// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and -// UDP packets destined to the IPv6 link-local all-nodes multicast address. -func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { - tests := []struct { - name string - protocolName string - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.ProtocolName6, testReceiveICMP}, - {"UDP", udp.ProtocolName, testReceiveUDP}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should receive a packet destined to the all-nodes - // multicast address. - test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) - }) - } -} - -// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP -// packets destined to the IPv6 solicited-node address of an assigned IPv6 -// address. -func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - tests := []struct { - name string - protocolName string - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.ProtocolName6, testReceiveICMP}, - {"UDP", udp.ProtocolName, testReceiveUDP}, - } - - snmc := header.SolicitedNodeAddr(addr2) - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as we haven't added - // those addresses. - test.rxf(t, s, e, addr1, snmc, 0) - - if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) - } - - // Should receive a packet destined to the solicited - // node address of addr2/addr3 now that we have added - // added addr2. - test.rxf(t, s, e, addr1, snmc, 1) - - if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) - } - - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have added addr3. - test.rxf(t, s, e, addr1, snmc, 2) - - if err := s.RemoveAddress(1, addr2); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) - } - - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have removed addr2. - test.rxf(t, s, e, addr1, snmc, 3) - - if err := s.RemoveAddress(1, addr3); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) - } - - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as both of them got - // removed. - test.rxf(t, s, e, addr1, snmc, 3) - }) - } -} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 249a19946..43719085e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -102,25 +102,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } -// enable enables the NIC. enable will attach the link to its LinkEndpoint and -// join the IPv6 All-Nodes Multicast address (ff02::1). -func (n *NIC) enable() *tcpip.Error { - n.attachLinkEndpoint() - - // Join the IPv6 All-Nodes Multicast group if the stack is configured to - // use IPv6. This is required to ensure that this node properly receives - // and responds to the various NDP messages that are destined to the - // all-nodes multicast address. An example is the Neighbor Advertisement - // when we perform Duplicate Address Detection, or Router Advertisement - // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 - // section 4.2 for more information. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) - } - - return nil -} - // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -358,15 +339,6 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } - // If we are adding an IPv6 address, join the solicited-node multicast - // address for a unicast protocolAddress. - if protocolAddress.Protocol == header.IPv6ProtocolNumber && !header.IsV6MulticastAddress(protocolAddress.AddressWithPrefix.Address) { - snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) - if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { - return nil, err - } - } - n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -495,27 +467,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.endpoints[NetworkEndpointID{addr}] - if !ok || r.getKind() != permanent { + r := n.endpoints[NetworkEndpointID{addr}] + if r == nil || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } r.setKind(permanentExpired) - if !r.decRefLocked() { - // The endpoint still has references to it. - return nil - } - - // At this point the endpoint is deleted. - - // If we are removing an IPv6 address, leave the solicited-node - // multicast address for a unicast addr. - if r.protocol == header.IPv6ProtocolNumber && !header.IsV6MulticastAddress(addr) { - snmc := header.SolicitedNodeAddr(addr) - if err := n.leaveGroupLocked(snmc); err != nil { - return err - } - } + r.decRefLocked() return nil } @@ -533,13 +491,6 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() - return n.joinGroupLocked(protocol, addr) -} - -// joinGroupLocked adds a new endpoint for the given multicast address, if none -// exists yet. Otherwise it just increments its count. n MUST be locked before -// joinGroupLocked is called. -func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -567,13 +518,6 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.leaveGroupLocked(addr) -} - -// leaveGroupLocked decrements the count for the given multicast address, and -// when it reaches zero removes the endpoint for this address. n MUST be locked -// before leaveGroupLocked is called. -func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -858,14 +802,11 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. Returns true if the endpoint was removed. -func (r *referencedNetworkEndpoint) decRefLocked() bool { +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) - return true } - - return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 1fe21b68e..a961e8ebe 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -633,7 +633,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, s.nics[id] = n if enabled { - return n.enable() + n.attachLinkEndpoint() } return nil @@ -680,7 +680,9 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - return nic.enable() + nic.attachLinkEndpoint() + + return nil } // CheckNIC checks if a NIC is usable. -- cgit v1.2.3 From df5d377521e625aeb8f4fe18bd1d9974dbf9998c Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Thu, 12 Sep 2019 15:09:01 -0700 Subject: Remove go_test from go_stateify and go_marshal They are no-ops, so the standard rule works fine. PiperOrigin-RevId: 268776264 --- pkg/abi/linux/BUILD | 4 +- pkg/amutex/BUILD | 3 +- pkg/atomicbitops/BUILD | 3 +- pkg/binary/BUILD | 3 +- pkg/bits/BUILD | 3 +- pkg/bpf/BUILD | 4 +- pkg/compressio/BUILD | 3 +- pkg/cpuid/BUILD | 4 +- pkg/eventchannel/BUILD | 3 +- pkg/fd/BUILD | 3 +- pkg/fdchannel/BUILD | 3 +- pkg/flipcall/BUILD | 3 +- pkg/fspath/BUILD | 3 +- pkg/gate/BUILD | 3 +- pkg/ilist/BUILD | 3 +- pkg/linewriter/BUILD | 3 +- pkg/log/BUILD | 3 +- pkg/metric/BUILD | 3 +- pkg/p9/BUILD | 3 +- pkg/p9/p9test/BUILD | 4 +- pkg/procid/BUILD | 3 +- pkg/refs/BUILD | 4 +- pkg/seccomp/BUILD | 4 +- pkg/secio/BUILD | 3 +- pkg/segment/test/BUILD | 3 +- pkg/sentry/control/BUILD | 3 +- pkg/sentry/device/BUILD | 4 +- pkg/sentry/fs/BUILD | 4 +- pkg/sentry/fs/fdpipe/BUILD | 4 +- pkg/sentry/fs/fsutil/BUILD | 4 +- pkg/sentry/fs/gofer/BUILD | 4 +- pkg/sentry/fs/host/BUILD | 4 +- pkg/sentry/fs/lock/BUILD | 4 +- pkg/sentry/fs/proc/BUILD | 4 +- pkg/sentry/fs/proc/seqfile/BUILD | 4 +- pkg/sentry/fs/ramfs/BUILD | 4 +- pkg/sentry/fs/tmpfs/BUILD | 4 +- pkg/sentry/fs/tty/BUILD | 4 +- pkg/sentry/fsimpl/ext/BUILD | 4 +- pkg/sentry/fsimpl/ext/benchmark/BUILD | 2 +- pkg/sentry/fsimpl/ext/disklayout/BUILD | 4 +- pkg/sentry/fsimpl/memfs/BUILD | 3 +- pkg/sentry/fsimpl/proc/BUILD | 3 +- pkg/sentry/hostcpu/BUILD | 3 +- pkg/sentry/kernel/BUILD | 3 +- pkg/sentry/kernel/epoll/BUILD | 4 +- pkg/sentry/kernel/eventfd/BUILD | 4 +- pkg/sentry/kernel/futex/BUILD | 4 +- pkg/sentry/kernel/pipe/BUILD | 4 +- pkg/sentry/kernel/sched/BUILD | 3 +- pkg/sentry/kernel/semaphore/BUILD | 4 +- pkg/sentry/limits/BUILD | 4 +- pkg/sentry/memmap/BUILD | 4 +- pkg/sentry/mm/BUILD | 4 +- pkg/sentry/pgalloc/BUILD | 4 +- pkg/sentry/platform/interrupt/BUILD | 3 +- pkg/sentry/platform/kvm/BUILD | 3 +- pkg/sentry/platform/ring0/pagetables/BUILD | 3 +- pkg/sentry/platform/safecopy/BUILD | 3 +- pkg/sentry/safemem/BUILD | 3 +- pkg/sentry/socket/netlink/port/BUILD | 4 +- pkg/sentry/time/BUILD | 3 +- pkg/sentry/usermem/BUILD | 4 +- pkg/sentry/vfs/BUILD | 3 +- pkg/sleep/BUILD | 3 +- pkg/state/BUILD | 3 +- pkg/state/statefile/BUILD | 3 +- pkg/syserror/BUILD | 3 +- pkg/tcpip/BUILD | 4 +- pkg/tcpip/adapters/gonet/BUILD | 3 +- pkg/tcpip/buffer/BUILD | 4 +- pkg/tcpip/hash/jenkins/BUILD | 3 +- pkg/tcpip/header/BUILD | 4 +- pkg/tcpip/link/fdbased/BUILD | 3 +- pkg/tcpip/link/muxed/BUILD | 3 +- pkg/tcpip/link/sharedmem/BUILD | 3 +- pkg/tcpip/link/sharedmem/pipe/BUILD | 3 +- pkg/tcpip/link/sharedmem/queue/BUILD | 3 +- pkg/tcpip/link/waitable/BUILD | 3 +- pkg/tcpip/network/BUILD | 2 +- pkg/tcpip/network/arp/BUILD | 3 +- pkg/tcpip/network/fragmentation/BUILD | 4 +- pkg/tcpip/network/ipv4/BUILD | 3 +- pkg/tcpip/network/ipv6/BUILD | 3 +- pkg/tcpip/ports/BUILD | 3 +- pkg/tcpip/stack/BUILD | 4 +- pkg/tcpip/transport/tcp/BUILD | 4 +- pkg/tcpip/transport/tcpconntrack/BUILD | 3 +- pkg/tcpip/transport/udp/BUILD | 4 +- pkg/tmutex/BUILD | 3 +- pkg/unet/BUILD | 3 +- pkg/urpc/BUILD | 3 +- pkg/waiter/BUILD | 4 +- tools/go_marshal/defs.bzl | 6 --- tools/go_marshal/test/BUILD | 4 +- tools/go_stateify/defs.bzl | 65 ++++++++++++++++++++---------- 96 files changed, 268 insertions(+), 123 deletions(-) (limited to 'pkg/tcpip/header') diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ba233b93f..39c92bb33 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -2,9 +2,11 @@ # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix # when the host OS may not be Linux. +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "linux", diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 39d253b98..6bc486b62 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 47ab65346..5f59866fa 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 09d6c2c1f..543fb54bf 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 0c2dde4f8..51967b811 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index b692aa3b1..8d31e068c 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "bpf", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index cdec96df1..a0b21d4bd 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index 830e19e07..32422f9e2 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "cpuid", diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9961baaa9..71f2abc83 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index 785c685a0..afa8f7659 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index e54e7371c..56495cbd9 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index c1e078c7c..5643d5f26 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index 11716af81..0c5f50397 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index e6a8dbd02..4b9321711 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 8f3defa25..34d2673ef 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,5 +1,6 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index c8e923a74..a5d980d14 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 12615240c..fc5f5779b 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 3b8a691f4..842788179 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index c6737bf97..6bc4d3bc7 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 6e939a49a..1d34181e0 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") package(licenses = ["notice"]) diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 697e7a2f4..078f084b2 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 9c08452fc..827385139 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "weak_ref_list", diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index d1024e49d..af94e944d 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") package(licenses = ["notice"]) diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index f38fb39f3..22abdc69f 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index 694486296..12d7c77d2 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:private"], diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index bf802d1b6..5522cecd0 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 7e8918722..0c86197f7 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "device", diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index d7259b47b..3119a61b6 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fs", diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index bf00b9c09..b9bd9ed17 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fdpipe", diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 6499f87ac..b4ac83dc4 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "dirty_set_impl", diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 6b993928c..2b71ca0e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "gofer", diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index b1080fb1a..3e532332e 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "host", diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 08d7c0c57..5a7a5b8cd 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "lock_range", diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index c7599d1f6..1c93e8886 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "proc", diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 20c3eefc8..76433c7d0 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "seqfile", diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 516efcc4c..d0f351e5a 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "ramfs", diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 8f7eb5757..11b680929 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tmpfs", diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5e9327aec..d799de748 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tty", diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 9e8ebb907..b0c286b7a 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") go_template_instance( diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 9fddb4c4c..bfc46dfa6 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 907d35b7e..2d50e30aa 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "disklayout", diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index d2450e810..7e364c5fd 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 3d8a4deaf..ade6ac946 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index f989f2f8b..d4a420e60 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e61d39c82..e964a991b 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,9 +1,10 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "pending_signals_list", diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index f46c43128..65427b112 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "epoll_list", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 1c5f979d4..983ca67ed 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "eventfd", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 6a31dc044..41f44999c 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "atomicptr_bucket", diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 4d15cca85..2ce8952e2 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "buffer_list", diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 1725b8562..98ea7a0d8 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 36edf10f3..80e5e5da3 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "waiter_list", diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 40025d62d..59649c770 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "limits", diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 29c14ec56..9687e7e76 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "mappable_range", diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 072745a08..b35c8c673 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "file_refcount_set", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 858f895f2..3fd904c67 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "evictable_range", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index eeb634644..b6d008dbe 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index fe979dccf..31fa48ec5 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 3b95af617..ea090b686 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 924d8a6d6..6769cd0a5 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index fd6dc8e6e..884020f7b 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 9e2e12799..445080aa4 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "port", diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 8aa6a3017..beb43ba13 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index a5b4206bb..cc5d25762 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "addr_range", diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 0f247bf77..eff4b44f6 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index 00665c939..bdca80d37 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/state/BUILD b/pkg/state/BUILD index c0f3c658d..329904457 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e70f4a79f..8a865d229 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index b149f9e02..bd3f9fd28 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index df37c7d5a..3fd9e3134 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tcpip", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 0d2637ee4..78df5a0b1 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 3301967fb..b4e8d6810 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "buffer", diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 29b30be9c..0c5c20cea 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 76ef02f13..b558350c3 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "header", diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 74fbbb896..8fa9e3984 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index ea12ef1ac..1bab380b0 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index f2998aa98..0a5ea3dc4 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 94725cb11..330ed5e94 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 160a8f864..de1ce043d 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 2597d4b3e..0746dc8ec 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index f36f49453..9d16ff8c9 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d95d44f56..df0d3a8c0 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 118bfc763..c5c7aad86 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "reassembler_list", diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index be84fa63d..58e537aad 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index c71b69123..a471abbfb 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 989058413..11efb4e44 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 788de3dfe..28c49e8ff 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "linkaddrentry_list", diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 1ee1a53f8..39a839ab7 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "tcp_segment_list", diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 4bec48c0f..43fcc27f0 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index ac2666f69..c1ca22b35 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "udp_packet_list", diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 98d51cc69..6afdb29b7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index cbd92fc05..8f6f180e5 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b7f505a84..b6bbb0ea2 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 9173dfd0f..8dc88becb 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "waiter_list", diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index 60a992b7f..c32eb559f 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -150,9 +150,3 @@ def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs): ], **kwargs ) - -def go_test(**kwargs): - """Wraps the standard go_test.""" - _go_test( - **kwargs - ) diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index 947011414..fa82f8e9b 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_marshal:defs.bzl", "go_library", "go_test") +load("//tools/go_marshal:defs.bzl", "go_library") package_group( name = "gomarshal_test", diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index aeba197e2..3ce36c1c8 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -35,7 +35,7 @@ go_library( ) """ -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test") +load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library") def _go_stateify_impl(ctx): """Implementation for the stateify tool.""" @@ -60,28 +60,57 @@ def _go_stateify_impl(ctx): executable = ctx.executable._tool, ) -# Generates save and restore logic from a set of Go files. -# -# Args: -# name: the name of the rule. -# srcs: the input source files. These files should include all structs in the package that need to be saved. -# imports: an optional list of extra non-aliased, Go-style absolute import paths. -# out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library. -# package: the package name for the input sources. go_stateify = rule( implementation = _go_stateify_impl, + doc = "Generates save and restore logic from a set of Go files.", attrs = { - "srcs": attr.label_list(mandatory = True, allow_files = True), - "imports": attr.string_list(mandatory = False), - "package": attr.string(mandatory = True), - "out": attr.output(mandatory = True), - "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_stateify:stateify")), + "srcs": attr.label_list( + doc = """ +The input source files. These files should include all structs in the package +that need to be saved. +""", + mandatory = True, + allow_files = True, + ), + "imports": attr.string_list( + doc = """ +An optional list of extra non-aliased, Go-style absolute import paths required +for statified types. +""", + mandatory = False, + ), + "package": attr.string( + doc = "The package name for the input sources.", + mandatory = True, + ), + "out": attr.output( + doc = """ +The name of the generated file output. This must not conflict with any other +files and must be added to the srcs of the relevant go_library. +""", + mandatory = True, + ), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/go_stateify:stateify"), + ), "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), }, ) def go_library(name, srcs, deps = [], imports = [], **kwargs): - """wraps the standard go_library and does stateification.""" + """Standard go_library wrapped which generates state source files. + + Args: + name: the name of the go_library rule. + srcs: sources of the go_library. Each will be processed for stateify + annotations. + deps: dependencies for the go_library. + imports: an optional list of extra non-aliased, Go-style absolute import + paths required for stateified types. + **kwargs: passed to go_library. + """ if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs: # Only do stateification for non-state packages without manual autogen. go_stateify( @@ -105,9 +134,3 @@ def go_library(name, srcs, deps = [], imports = [], **kwargs): deps = all_deps, **kwargs ) - -def go_test(**kwargs): - """Wraps the standard go_test.""" - _go_test( - **kwargs - ) -- cgit v1.2.3 From 7c6ab6a219f37a1d4c18ced4a602458fcf363f85 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Thu, 12 Sep 2019 17:42:14 -0700 Subject: Implement splice methods for pipes and sockets. This also allows the tee(2) implementation to be enabled, since dup can now be properly supported via WriteTo. Note that this change necessitated some minor restructoring with the fs.FileOperations splice methods. If the *fs.File is passed through directly, then only public API methods are accessible, which will deadlock immediately since the locking is already done by fs.Splice. Instead, we pass through an abstract io.Reader or io.Writer, which elide locks and use the underlying fs.FileOperations directly. PiperOrigin-RevId: 268805207 --- pkg/sentry/fs/file.go | 23 +++- pkg/sentry/fs/file_operations.go | 9 +- pkg/sentry/fs/file_overlay.go | 9 +- pkg/sentry/fs/fsutil/file.go | 6 +- pkg/sentry/fs/inotify.go | 5 +- pkg/sentry/fs/splice.go | 162 +++++++++++++------------- pkg/sentry/kernel/pipe/buffer.go | 25 ++++ pkg/sentry/kernel/pipe/pipe.go | 82 +++++++++++--- pkg/sentry/kernel/pipe/reader_writer.go | 76 ++++++++++++- pkg/sentry/socket/epsocket/epsocket.go | 134 +++++++++++++++++++--- pkg/sentry/syscalls/linux/linux64.go | 4 +- pkg/sentry/syscalls/linux/sys_splice.go | 86 +++++++------- pkg/tcpip/header/udp.go | 5 + pkg/tcpip/stack/transport_test.go | 4 +- pkg/tcpip/tcpip.go | 48 ++++---- pkg/tcpip/transport/icmp/endpoint.go | 4 +- pkg/tcpip/transport/raw/endpoint.go | 7 +- pkg/tcpip/transport/tcp/endpoint.go | 68 ++++++----- pkg/tcpip/transport/udp/endpoint.go | 14 +-- test/syscalls/linux/BUILD | 3 + test/syscalls/linux/pipe.cc | 14 +++ test/syscalls/linux/sendfile.cc | 69 ++++++++++++ test/syscalls/linux/splice.cc | 194 +++++++++++++++++++++++++------- 23 files changed, 770 insertions(+), 281 deletions(-) (limited to 'pkg/tcpip/header') diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index bb8117f89..c0a6e884b 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -515,6 +515,11 @@ type lockedReader struct { // File is the file to read from. File *File + + // Offset is the offset to start at. + // + // This applies only to Read, not ReadAt. + Offset int64 } // Read implements io.Reader.Read. @@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) { if r.Ctx.Interrupted() { return 0, syserror.ErrInterrupted } - n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset) + n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset) + r.Offset += n return int(n), err } @@ -544,11 +550,21 @@ type lockedWriter struct { // File is the file to write to. File *File + + // Offset is the offset to start at. + // + // This applies only to Write, not WriteAt. + Offset int64 } // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { - return w.WriteAt(buf, w.File.offset) + if w.Ctx.Interrupted() { + return 0, syserror.ErrInterrupted + } + n, err := w.WriteAt(buf, w.Offset) + w.Offset += int64(n) + return int(n), err } // WriteAt implements io.Writer.WriteAt. @@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { // io.Copy, since our own Write interface does not have this same // contract. Enforce that here. for written < len(buf) { + if w.Ctx.Interrupted() { + return written, syserror.ErrInterrupted + } var n int64 n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) if n > 0 { diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index d86f5bf45..b88303f17 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -15,6 +15,8 @@ package fs import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -105,8 +107,11 @@ type FileOperations interface { // on the destination, following by a buffered copy with standard Read // and Write operations. // + // If dup is set, the data should be duplicated into the destination + // and retained. + // // The same preconditions as Read apply. - WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error) + WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error) // Write writes src to file at offset and returns the number of bytes // written which must be greater than or equal to 0. Like Read, file @@ -126,7 +131,7 @@ type FileOperations interface { // source. See WriteTo for details regarding how this is called. // // The same preconditions as Write apply; FileFlags.Write must be set. - ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error) + ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error) // Fsync writes buffered modifications of file and/or flushes in-flight // operations to backing storage based on syncType. The range to sync is diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 9820f0b13..225e40186 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "gvisor.dev/gvisor/pkg/refs" @@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme } // WriteTo implements FileOperations.WriteTo. -func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) { err = f.onTop(ctx, file, func(file *File, ops FileOperations) error { - n, err = ops.WriteTo(ctx, file, dst, opts) + n, err = ops.WriteTo(ctx, file, dst, count, dup) return err // Will overwrite itself. }) return @@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm } // ReadFrom implements FileOperations.ReadFrom. -func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) { // See above; f.upper must be non-nil. - return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts) + return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count) } // Fsync implements FileOperations.Fsync. diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 626b9126a..fc5b3b1a1 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -15,6 +15,8 @@ package fsutil import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu type FileNoSplice struct{} // WriteTo implements fs.FileOperations.WriteTo. -func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } // ReadFrom implements fs.FileOperations.ReadFrom. -func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index c7f4e2d13..ba3e0233d 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "sync/atomic" @@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } // WriteTo implements FileOperations.WriteTo. -func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } @@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { } // ReadFrom implements FileOperations.ReadFrom. -func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index eed1c2854..b03b7f836 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -18,7 +18,6 @@ import ( "io" "sync/atomic" - "gvisor.dev/gvisor/pkg/secio" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserror" ) @@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } // Check whether or not the objects being sliced are stream-oriented - // (i.e. pipes or sockets). If yes, we elide checks and offset locks. - srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr) - dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr) + // (i.e. pipes or sockets). For all stream-oriented files and files + // where a specific offiset is not request, we acquire the file mutex. + // This has two important side effects. First, it provides the standard + // protection against concurrent writes that would mutate the offset. + // Second, it prevents Splice deadlocks. Only internal anonymous files + // implement the ReadFrom and WriteTo methods directly, and since such + // anonymous files are referred to by a unique fs.File object, we know + // that the file mutex takes strict precedence over internal locks. + // Since we enforce lock ordering here, we can't deadlock by using + // using a file in two different splice operations simultaneously. + srcPipe := !IsRegular(src.Dirent.Inode.StableAttr) + dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr) + dstAppend := !dstPipe && dst.Flags().Append + srcLock := srcPipe || !opts.SrcOffset + dstLock := dstPipe || !opts.DstOffset || dstAppend - if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset { + switch { + case srcLock && dstLock: switch { case dst.UniqueID < src.UniqueID: // Acquire dst first. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() if !src.mu.Lock(ctx) { + dst.mu.Unlock() return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() case dst.UniqueID > src.UniqueID: // Acquire src first. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() if !dst.mu.Lock(ctx) { + src.mu.Unlock() return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() case dst.UniqueID == src.UniqueID: // Acquire only one lock; it's the same file. This is a // bit of a edge case, but presumably it's possible. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() + srcLock = false // Only need one unlock. } // Use both offsets (locked). opts.DstStart = dst.offset opts.SrcStart = src.offset - } else if !dstPipe && !opts.DstOffset { + case dstLock: // Acquire only dst. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() opts.DstStart = dst.offset // Safe: locked. - } else if !srcPipe && !opts.SrcOffset { + case srcLock: // Acquire only src. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() opts.SrcStart = src.offset // Safe: locked. } - // Check append-only mode and the limit. - if !dstPipe { + var err error + if dstAppend { unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append) defer unlock() - if dst.Flags().Append { - if opts.DstOffset { - // We need to acquire the lock. - if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted - } - defer dst.mu.Unlock() - } - // Figure out the appropriate offset to use. - if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil { - return 0, err - } - } + // Figure out the appropriate offset to use. + err = dst.offsetForAppend(ctx, &opts.DstStart) + } + if err == nil && !dstPipe { // Enforce file limits. limit, ok := dst.checkLimit(ctx, opts.DstStart) switch { case ok && limit == 0: - return 0, syserror.ErrExceedsFileSizeLimit + err = syserror.ErrExceedsFileSizeLimit case ok && limit < opts.Length: opts.Length = limit // Cap the write. } } + if err != nil { + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return 0, err + } - // Attempt to do a WriteTo; this is likely the most efficient. - // - // The underlying implementation may be able to donate buffers. - newOpts := SpliceOpts{ - Length: opts.Length, - SrcStart: opts.SrcStart, - SrcOffset: !srcPipe, - Dup: opts.Dup, - DstStart: opts.DstStart, - DstOffset: !dstPipe, + // Construct readers and writers for the splice. This is used to + // provide a safer locking path for the WriteTo/ReadFrom operations + // (since they will otherwise go through public interface methods which + // conflict with locking done above), and simplifies the fallback path. + w := &lockedWriter{ + Ctx: ctx, + File: dst, + Offset: opts.DstStart, } - n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts) - if n == 0 && err != nil { - // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also - // be more efficient than a copy if buffers are cached or readily - // available. (It's unlikely that they can actually be donate - n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts) + r := &lockedReader{ + Ctx: ctx, + File: src, + Offset: opts.SrcStart, } - if n == 0 && err != nil { - // If we've failed up to here, and at least one of the sources - // is a pipe or socket, then we can't properly support dup. - // Return an error indicating that this operation is not - // supported. - if (srcPipe || dstPipe) && newOpts.Dup { - return 0, syserror.EINVAL - } - // We failed to splice the files. But that's fine; we just fall - // back to a slow path in this case. This copies without doing - // any mode changes, so should still be more efficient. - var ( - r io.Reader - w io.Writer - ) - fw := &lockedWriter{ - Ctx: ctx, - File: dst, - } - if newOpts.DstOffset { - // Use the provided offset. - w = secio.NewOffsetWriter(fw, newOpts.DstStart) - } else { - // Writes will proceed with no offset. - w = fw - } - fr := &lockedReader{ - Ctx: ctx, - File: src, - } - if newOpts.SrcOffset { - // Limit to the given offset and length. - r = io.NewSectionReader(fr, opts.SrcStart, opts.Length) - } else { - // Limit just to the given length. - r = &io.LimitedReader{fr, opts.Length} - } + // Attempt to do a WriteTo; this is likely the most efficient. + n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup { + // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be + // more efficient than a copy if buffers are cached or readily + // available. (It's unlikely that they can actually be donated). + n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length) + } - // Copy between the two. - n, err = io.Copy(w, r) + // Support one last fallback option, but only if at least one of + // the source and destination are regular files. This is because + // if we block at some point, we could lose data. If the source is + // not a pipe then reading is not destructive; if the destination + // is a regular file, then it is guaranteed not to block writing. + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) { + // Fallback to an in-kernel copy. + n, err = io.Copy(w, &io.LimitedReader{ + R: r, + N: opts.Length, + }) } // Update offsets, if required. @@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } } + // Drop locks. + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return n, err } diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 69ef2a720..95bee2d37 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" @@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } +// WriteFromReader writes to the buffer from an io.Reader. +func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { + dst := b.data[b.write:] + if count < int64(len(dst)) { + dst = b.data[b.write:][:count] + } + n, err := r.Read(dst) + b.write += n + return int64(n), err +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) @@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return n, err } +// ReadToWriter reads from the buffer into an io.Writer. +func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { + src := b.data[b.read:b.write] + if count < int64(len(src)) { + src = b.data[b.read:][:count] + } + n, err := w.Write(src) + if !dup { + b.read += n + } + return int64(n), err +} + // bufferPool is a pool for buffers. var bufferPool = sync.Pool{ New: func() interface{} { diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 247e2928e..93b50669f 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } +type readOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit limits subsequence reads. + limit func(int64) + + // read performs the actual read operation. + read func(*buffer) (int64, error) +} + // read reads data from the pipe into dst and returns the number of bytes // read, or returns ErrWouldBlock if the pipe is empty. // // Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) { +func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if dst.NumBytes() == 0 { + if ops.left() == 0 { return 0, nil } @@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Limit how much we consume. - if dst.NumBytes() > p.size { - dst = dst.TakeFirst64(p.size) + if ops.left() > p.size { + ops.limit(p.size) } done := int64(0) - for dst.NumBytes() > 0 { + for ops.left() > 0 { // Pop the first buffer. first := p.data.Front() if first == nil { @@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Copy user data. - n, err := dst.CopyOutFrom(ctx, first) + n, err := ops.read(first) done += int64(n) p.size -= n - dst = dst.DropFirst64(n) // Empty buffer? if first.Empty() { @@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) return done, nil } +// dup duplicates all data from this pipe into the given writer. +// +// There is no blocking behavior implemented here. The writer may propagate +// some blocking error. All the writes must be complete writes. +func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Is the pipe empty? + if p.size == 0 { + if !p.HasWriters() { + // See above. + return 0, nil + } + return 0, syserror.ErrWouldBlock + } + + // Limit how much we consume. + if ops.left() > p.size { + ops.limit(p.size) + } + + done := int64(0) + for buf := p.data.Front(); buf != nil; buf = buf.Next() { + n, err := ops.read(buf) + done += n + if err != nil { + return done, err + } + } + + return done, nil +} + +type writeOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit should limit subsequent writes. + limit func(int64) + + // write should write to the provided buffer. + write func(*buffer) (int64, error) +} + // write writes data from sv into the pipe and returns the number of bytes // written. If no bytes are written because the pipe is full (or has less than // atomicIOBytes free capacity), write returns ErrWouldBlock. // // Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) { +func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() @@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := src.NumBytes() + wanted := ops.left() if avail := p.max - p.size; wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } - // Limit to the available capacity. - src = src.TakeFirst64(avail) + ops.limit(avail) } done := int64(0) - for src.NumBytes() > 0 { + for ops.left() > 0 { // Need a new buffer? last := p.data.Back() if last == nil || last.Full() { @@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) } // Copy user data. - n, err := src.CopyInTo(ctx, last) + n, err := ops.write(last) done += int64(n) p.size += n - src = src.DropFirst64(n) // Handle errors. if err != nil { diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index f69dbf27b..7c307f013 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "math" "syscall" @@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() { // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, dst) + n, err := rw.Pipe.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return rw.Pipe.dup(ctx, ops) + } + n, err := rw.Pipe.read(ctx, ops) if n > 0 { rw.Pipe.Notify(waiter.EventOut) } @@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, src) + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) if n > 0 { rw.Pipe.Notify(waiter.EventIn) } diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 0e37ce61b..3e05e40fe 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -26,6 +26,7 @@ package epsocket import ( "bytes" + "io" "math" "reflect" "sync" @@ -227,7 +228,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -412,17 +412,58 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + defer s.readMu.Unlock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + return done, nil + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + return done, err + } + } + + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -431,11 +472,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -469,6 +505,76 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -2060,7 +2166,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2085,7 +2191,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ed996ba51..150999fb8 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -320,8 +320,8 @@ var AMD64 = &kernel.SyscallTable{ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 8a98fedcb..f0a292f2f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB total int64 n int64 err error - ch chan struct{} - inW bool - outW bool + inCh chan struct{} + outCh chan struct{} ) for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) @@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB break } - // Are we a registered waiter? - if ch == nil { - ch = make(chan struct{}, 1) - } - if !inW && !inFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - inFile.EventRegister(&w, EventMaskRead) - defer inFile.EventUnregister(&w) - inW = true // Registered. - } else if !outW && !outFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - outFile.EventRegister(&w, EventMaskWrite) - defer outFile.EventUnregister(&w) - outW = true // Registered. - } - - // Was anything registered? If no, everything is non-blocking. - if !inW && !outW { - break - } - - if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) { - // Something became ready, try again without blocking. - continue + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(EventMaskRead) == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, EventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } } - - // Block until there's data. - if err = t.Block(ch); err != nil { - break + if outFile.Readiness(EventMaskWrite) == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } } } @@ -149,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Length: count, SrcOffset: true, SrcStart: offset, - }, false) + }, outFile.Flags().NonBlocking) // Copy out the new offset. if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { @@ -159,7 +156,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Send data using splice. n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, - }, false) + }, outFile.Flags().NonBlocking) } // We can only pass a single file to handleIOError, so pick inFile @@ -181,12 +178,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. Note that unlike in Linux, this - // flag is applied consistently. We will have either fully blocking or - // non-blocking behavior below, regardless of the underlying files - // being spliced to. It's unclear if this is a bug or not yet. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -200,6 +191,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer inFile.DecRef() + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Construct our options. // // Note that exactly one of the underlying buffers must be a pipe. We @@ -257,7 +255,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Splice data. - n, err := doSplice(t, outFile, inFile, opts, nonBlocking) + n, err := doSplice(t, outFile, inFile, opts, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) @@ -275,9 +273,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -301,11 +296,14 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } + // The operation is non-blocking if anything is non-blocking. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Splice data. n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, Dup: true, - }, nonBlocking) + }, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index c1f454805..74412c894 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -27,6 +27,11 @@ const ( udpChecksum = 6 ) +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 87d1e0d0d..847d02982 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ebf8a2d04..2534069ab 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -261,31 +261,34 @@ type FullAddress struct { Port uint16 } -// Payload provides an interface around data that is being sent to an endpoint. -// This allows the endpoint to request the amount of data it needs based on -// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data. -type Payload interface { - // Get returns a slice containing exactly 'min(size, p.Size())' bytes. - Get(size int) ([]byte, *Error) - - // Size returns the payload size. - Size() int +// Payloader is an interface that provides data. +// +// This interface allows the endpoint to request the amount of data it needs +// based on internal buffers without exposing them. +type Payloader interface { + // FullPayload returns all available bytes. + FullPayload() ([]byte, *Error) + + // Payload returns a slice containing at most size bytes. + Payload(size int) ([]byte, *Error) } -// SlicePayload implements Payload on top of slices for convenience. +// SlicePayload implements Payloader for slices. +// +// This is typically used for tests. type SlicePayload []byte -// Get implements Payload. -func (s SlicePayload) Get(size int) ([]byte, *Error) { - if size > s.Size() { - size = s.Size() - } - return s[:size], nil +// FullPayload implements Payloader.FullPayload. +func (s SlicePayload) FullPayload() ([]byte, *Error) { + return s, nil } -// Size implements Payload. -func (s SlicePayload) Size() int { - return len(s) +// Payload implements Payloader.Payload. +func (s SlicePayload) Payload(size int) ([]byte, *Error) { + if size > len(s) { + size = len(s) + } + return s[:size], nil } // A ControlMessages contains socket control messages for IP sockets. @@ -338,7 +341,7 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // @@ -432,6 +435,11 @@ type WriteOptions struct { // EndOfRecord has the same semantics as Linux's MSG_EOR. EndOfRecord bool + + // Atomic means that all data fetched from Payloader must be written to the + // endpoint. If Atomic is false, then data fetched from the Payloader may be + // discarded if available endpoint buffer space is unsufficient. + Atomic bool } // SockOpt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e1f622af6..3db060384 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -204,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -289,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 13e17e2a6..cf1c5c433 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 return 0, nil, tcpip.ErrInvalidEndpointState } - payloadBytes, err := payload.Get(payload.Size()) + payloadBytes, err := p.FullPayload() if err != nil { - ep.mu.RUnlock() return 0, nil, err } @@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 // destination address, route using that address. if !ep.associated { ip := header.IPv4(payloadBytes) - if !ip.IsValid(payload.Size()) { + if !ip.IsValid(len(payloadBytes)) { ep.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ac927569a..dd931f88c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -806,7 +806,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -821,47 +821,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, err } - e.sndBufMu.Unlock() - e.mu.RUnlock() - - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.mu.RUnlock() } - // Copy in memory without holding sndBufMu so that worker goroutine can - // make progress independent of this operation. - v, perr := p.Get(avail) - if perr != nil { + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + if opts.Atomic { // See above. + e.sndBufMu.Unlock() + e.mu.RUnlock() + } + // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - e.mu.RLock() - e.sndBufMu.Lock() + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a - // write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - return 0, nil, err - } + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } - // Discard any excess data copied in due to avail being reduced due to a - // simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } // Add data to the send queue. - l := len(v) s := newSegmentFromView(&e.route, e.id, v) - e.sndBufUsed += l - e.sndBufInQueue += seqnum.Size(l) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() // Release the endpoint lock to prevent deadlocks due to lock // order inversion when acquiring workMu. @@ -875,7 +880,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return int64(l), nil, nil + + return int64(len(v)), nil, nil } // Peek reads data without consuming it from the endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index dccb9a7eb..6ac7c067a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -277,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -370,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 34057e3d0..df00d2c14 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1867,7 +1867,9 @@ cc_binary( "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) @@ -1901,6 +1903,7 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 65afb90f3..10e2a6dfc 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -168,6 +168,20 @@ TEST_P(PipeTest, Write) { EXPECT_EQ(wbuf, rbuf); } +TEST_P(PipeTest, WritePage) { + SKIP_IF(!CreateBlocking()); + + std::vector wbuf(kPageSize); + RandomizeBuffer(wbuf.data(), wbuf.size()); + std::vector rbuf(wbuf.size()); + + ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()), + SyscallSucceedsWithValue(wbuf.size())); + ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0); +} + TEST_P(PipeTest, NonBlocking) { SKIP_IF(!CreateNonBlocking()); diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 9167ab066..4502e7fb4 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -19,9 +19,12 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -442,6 +445,72 @@ TEST(SendFileTest, SendToNotARegularFile) { EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0), SyscallFailsWithErrno(EINVAL)); } + +TEST(SendFileTest, SendPipeWouldBlock) { + // Create temp file. + constexpr char kData[] = + "The fool doth think he is wise, but the wise man knows himself to be a " + "fool."; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill up the pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector buf(2 * pipe_size); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST(SendFileTest, SendPipeBlocks) { + // Create temp file. + constexpr char kData[] = + "The fault, dear Brutus, is not in our stars, but in ourselves."; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill up the pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector buf(pipe_size); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + }); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), + SyscallSucceedsWithValue(kDataSize)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index e25f264f6..85232cb1f 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -14,12 +14,16 @@ #include #include +#include #include +#include #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -36,23 +40,23 @@ TEST(SpliceTest, TwoRegularFiles) { const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); // Open the input file as read only. - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); // Open the output file as write only. - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); // Verify that it is rejected as expected; regardless of offsets. loff_t in_offset = 0; loff_t out_offset = 0; - EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), &out_offset, 1, 0), + EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), &out_offset, 1, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), nullptr, 1, 0), + EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), nullptr, 1, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0), SyscallFailsWithErrno(EINVAL)); } @@ -75,8 +79,6 @@ TEST(SpliceTest, SamePipe) { } TEST(TeeTest, SamePipe) { - SKIP_IF(IsRunningOnGvisor()); - // Create a new pipe. int fds[2]; ASSERT_THAT(pipe(fds), SyscallSucceeds()); @@ -95,11 +97,9 @@ TEST(TeeTest, SamePipe) { } TEST(TeeTest, RegularFile) { - SKIP_IF(IsRunningOnGvisor()); - // Open some file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Create a new pipe. @@ -109,9 +109,9 @@ TEST(TeeTest, RegularFile) { const FileDescriptor wfd(fds[1]); // Attempt to tee from the file. - EXPECT_THAT(tee(inf.get(), wfd.get(), kPageSize, 0), + EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(tee(rfd.get(), inf.get(), kPageSize, 0), + EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0), SyscallFailsWithErrno(EINVAL)); } @@ -142,7 +142,7 @@ TEST(SpliceTest, FromEventFD) { constexpr uint64_t kEventFDValue = 1; int efd; ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds()); - const FileDescriptor inf(efd); + const FileDescriptor in_fd(efd); // Create a new pipe. int fds[2]; @@ -152,7 +152,7 @@ TEST(SpliceTest, FromEventFD) { // Splice 8-byte eventfd value to pipe. constexpr int kEventFDSize = 8; - EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), SyscallSucceedsWithValue(kEventFDSize)); // Contents should be equal. @@ -166,7 +166,7 @@ TEST(SpliceTest, FromEventFD) { TEST(SpliceTest, FromEventFDOffset) { int efd; ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor inf(efd); + const FileDescriptor in_fd(efd); // Create a new pipe. int fds[2]; @@ -179,7 +179,7 @@ TEST(SpliceTest, FromEventFDOffset) { // This is not allowed because eventfd doesn't support pread. constexpr int kEventFDSize = 8; loff_t in_off = 0; - EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), + EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), SyscallFailsWithErrno(EINVAL)); } @@ -200,28 +200,29 @@ TEST(SpliceTest, ToEventFDOffset) { int efd; ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor outf(efd); + const FileDescriptor out_fd(efd); // Attempt to splice 8-byte eventfd value to pipe with offset. // // This is not allowed because eventfd doesn't support pwrite. loff_t out_off = 0; - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0), - SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0), + SyscallFailsWithErrno(EINVAL)); } TEST(SpliceTest, ToPipe) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Fill with some random data. std::vector buf(kPageSize); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(inf.get(), buf.data(), buf.size()), + ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(inf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); // Create a new pipe. int fds[2]; @@ -230,7 +231,7 @@ TEST(SpliceTest, ToPipe) { const FileDescriptor wfd(fds[1]); // Splice to the pipe. - EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), SyscallSucceedsWithValue(kPageSize)); // Contents should be equal. @@ -243,13 +244,13 @@ TEST(SpliceTest, ToPipe) { TEST(SpliceTest, ToPipeOffset) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Fill with some random data. std::vector buf(kPageSize); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(inf.get(), buf.data(), buf.size()), + ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); // Create a new pipe. @@ -261,7 +262,7 @@ TEST(SpliceTest, ToPipeOffset) { // Splice to the pipe. loff_t in_offset = kPageSize / 2; EXPECT_THAT( - splice(inf.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0), + splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0), SyscallSucceedsWithValue(kPageSize / 2)); // Contents should be equal to only the second part. @@ -286,22 +287,22 @@ TEST(SpliceTest, FromPipe) { // Open the input file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); // Splice to the output file. - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), nullptr, kPageSize, 0), + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0), SyscallSucceedsWithValue(kPageSize)); // The offset of the output should be equal to kPageSize. We assert that and // reset to zero so that we can read the contents and ensure they match. - EXPECT_THAT(lseek(outf.get(), 0, SEEK_CUR), + EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(outf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); // Contents should be equal. std::vector rbuf(kPageSize); - ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()), + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), SyscallSucceedsWithValue(kPageSize)); EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } @@ -321,18 +322,19 @@ TEST(SpliceTest, FromPipeOffset) { // Open the input file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); // Splice to the output file. loff_t out_offset = kPageSize / 2; - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_offset, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); // Content should reflect the splice. We write to a specific offset in the // file, so the internals should now be allocated sparsely. std::vector rbuf(kPageSize); - ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()), + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), SyscallSucceedsWithValue(kPageSize)); std::vector zbuf(kPageSize / 2); memset(zbuf.data(), 0, zbuf.size()); @@ -404,8 +406,6 @@ TEST(SpliceTest, Blocking) { } TEST(TeeTest, Blocking) { - SKIP_IF(IsRunningOnGvisor()); - // Create two new pipes. int first[2], second[2]; ASSERT_THAT(pipe(first), SyscallSucceeds()); @@ -440,6 +440,49 @@ TEST(TeeTest, Blocking) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); } +TEST(TeeTest, BlockingWrite) { + // Create two new pipes. + int first[2], second[2]; + ASSERT_THAT(pipe(first), SyscallSucceeds()); + const FileDescriptor rfd1(first[0]); + const FileDescriptor wfd1(first[1]); + ASSERT_THAT(pipe(second), SyscallSucceeds()); + const FileDescriptor rfd2(second[0]); + const FileDescriptor wfd2(second[1]); + + // Make some data available to be read. + std::vector buf1(kPageSize); + RandomizeBuffer(buf1.data(), buf1.size()); + ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Fill up the write pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector buf2(pipe_size); + ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()), + SyscallSucceedsWithValue(pipe_size)); + + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()), + SyscallSucceedsWithValue(pipe_size)); + }); + + // Attempt a tee immediately; it should block. + EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); + + // Thread should be joinable. + t.Join(); + + // Content should reflect the tee. + std::vector rbuf(kPageSize); + ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0); +} + TEST(SpliceTest, NonBlocking) { // Create two new pipes. int first[2], second[2]; @@ -457,8 +500,6 @@ TEST(SpliceTest, NonBlocking) { } TEST(TeeTest, NonBlocking) { - SKIP_IF(IsRunningOnGvisor()); - // Create two new pipes. int first[2], second[2]; ASSERT_THAT(pipe(first), SyscallSucceeds()); @@ -473,6 +514,79 @@ TEST(TeeTest, NonBlocking) { SyscallFailsWithErrno(EAGAIN)); } +TEST(TeeTest, MultiPage) { + // Create two new pipes. + int first[2], second[2]; + ASSERT_THAT(pipe(first), SyscallSucceeds()); + const FileDescriptor rfd1(first[0]); + const FileDescriptor wfd1(first[1]); + ASSERT_THAT(pipe(second), SyscallSucceeds()); + const FileDescriptor rfd2(second[0]); + const FileDescriptor wfd2(second[1]); + + // Make some data available to be read. + std::vector wbuf(8 * kPageSize); + RandomizeBuffer(wbuf.data(), wbuf.size()); + ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()), + SyscallSucceedsWithValue(wbuf.size())); + + // Attempt a tee immediately; it should complete. + EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0), + SyscallSucceedsWithValue(wbuf.size())); + + // Content should reflect the tee. + std::vector rbuf(wbuf.size()); + ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); + ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); +} + +TEST(SpliceTest, FromPipeMaxFileSize) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Open the input file. + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor out_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); + + EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds()); + EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END), + SyscallSucceedsWithValue(13 << 20)); + + // Set our file size limit. + sigset_t set; + sigemptyset(&set); + sigaddset(&set, SIGXFSZ); + TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); + rlimit rlim = {}; + rlim.rlim_cur = rlim.rlim_max = (13 << 20); + EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds()); + + // Splice to the output file. + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0), + SyscallFailsWithErrno(EFBIG)); + + // Contents should be equal. + std::vector rbuf(kPageSize); + ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); +} + } // namespace } // namespace testing -- cgit v1.2.3 From 60fe8719e172f76aa5cfd8cd80a35c3e648701a3 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 17 Sep 2019 14:45:41 -0700 Subject: Automated rollback of changelist 268047073 PiperOrigin-RevId: 269658971 --- pkg/tcpip/header/ipv6.go | 25 ++++ pkg/tcpip/network/ipv6/BUILD | 2 + pkg/tcpip/network/ipv6/icmp_test.go | 6 - pkg/tcpip/network/ipv6/ipv6_test.go | 258 ++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/nic.go | 71 +++++++++- pkg/tcpip/stack/stack.go | 6 +- 6 files changed, 353 insertions(+), 15 deletions(-) create mode 100644 pkg/tcpip/network/ipv6/ipv6_test.go (limited to 'pkg/tcpip/header') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 093850e25..9d3abc0e4 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -76,6 +76,13 @@ const ( // IPv6Version is the version of the ipv6 protocol. IPv6Version = 6 + // IPv6AllNodesMulticastAddress is a link-local multicast group that + // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all nodes on a link. + // + // The address is ff02::1. + IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 @@ -221,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool { return addr[0] == 0xff } +// IsV6UnicastAddress determines if the provided address is a valid IPv6 +// unicast (and specified) address. That is, IsV6UnicastAddress returns +// true if addr contains IPv6AddressSize bytes, is not the unspecified +// address and is not a multicast address. +func IsV6UnicastAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + + // Must not be unspecified + if addr == IPv6Any { + return false + } + + // Return if not a multicast. + return addr[0] != 0xff +} + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a471abbfb..f06622a8b 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -26,6 +26,7 @@ go_test( size = "small", srcs = [ "icmp_test.go", + "ipv6_test.go", "ndp_test.go", ], embed = [":ipv6"], @@ -37,6 +38,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a6a1a5232..653d984e9 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -222,9 +222,6 @@ func newTestContext(t *testing.T) *testContext { if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) @@ -234,9 +231,6 @@ func newTestContext(t *testing.T) *testContext { if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..57bcd5455 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,258 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveICMP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as we haven't added + // those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited + // node address of addr2/addr3 now that we have added + // added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(1, addr2); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + if err := s.RemoveAddress(1, addr3); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as both of them got + // removed. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} + +// TestAddIpv6Address tests adding IPv6 addresses. +func TestAddIpv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + }{ + // This test is in response to b/140943433. + { + "Nil", + tcpip.Address([]byte(nil)), + }, + { + "ValidUnicast", + addr1, + }, + { + "ValidLinkLocalUnicast", + lladdr0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, nil, stack.Options{}) + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + } + }) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43719085e..a719058b4 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -102,6 +102,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -307,6 +326,8 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + // Sanity check. id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if _, ok := n.endpoints[id]; ok { @@ -339,6 +360,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -467,13 +497,27 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || r.getKind() != permanent { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } r.setKind(permanentExpired) - r.decRefLocked() + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } + + // At this point the endpoint is deleted. + + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -491,6 +535,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -518,6 +569,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -802,11 +860,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a961e8ebe..1fe21b68e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -633,7 +633,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil @@ -680,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. -- cgit v1.2.3