diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/parse/parse.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/internal/fragmentation/reassembler.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 367 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/utils/utils.go | 8 |
18 files changed, 248 insertions, 272 deletions
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index ebb4b2c1d..1c913b5e1 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -60,9 +60,13 @@ func IPv4(pkt *stack.PacketBuffer) bool { return false } ipHdr = header.IPv4(hdr) + length := int(ipHdr.TotalLength()) - len(hdr) + if length < 0 { + return false + } pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(length) return true } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 90075a70c..56b76a284 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -167,8 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s resPkt := r.holes[0].pkt for i := 1; i < len(r.holes); i++ { - fragData := r.holes[i].pkt.Data() - resPkt.Data().ReadFromData(fragData, fragData.Size()) + stack.MergeFragment(resPkt, r.holes[i].pkt) } return resPkt, r.proto, true, memConsumed, nil } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 74aad126c..bd63e0289 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) @@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, false); err != nil { - t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 3c8a39973..5f45b9ee6 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -389,8 +389,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e457be3cf..040cd4bc8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 880290b4b..febbb3f38 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1472,13 +1472,19 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + // Calculate the number of octets parsed from data. We want to remove all + // the data except the unparsed portion located at the end, which its size + // is extHdr.Buf.Size(). + trim := pkt.Data().Size() - extHdr.Buf.Size() + // For unfragmented packets, extHdr still contains the transport header. // Get rid of it. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) - pkt.Data().Replace(extHdr.Buf) + trim += pkt.TransportHeader().View().Size() + + pkt.Data().DeleteFront(trim) stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index faf6a782e..30325160a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -3329,8 +3329,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) } transportProtocol := header.ICMPv6ProtocolNumber diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 52b9a200c..570c6c00c 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -878,8 +878,9 @@ func TestNDPValidation(t *testing.T) { s, ep := setup(t) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } stats := s.Stats().ICMP.V6.PacketsReceived diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 63ab31083..84aa6a9e4 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -74,6 +74,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/atomicbitops", + "//pkg/buffer", "//pkg/ilist", "//pkg/log", "//pkg/rand", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7d3725681..ff555722e 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -367,8 +367,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f }}, }) - // Enable forwarding. - s.SetForwarding(proto.Number(), true) + protoNum := proto.Number() + if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) + } // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index c585b81b2..d4ac9e1f8 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1220,8 +1220,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { NDPDisp: &ndpDisp, })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } e := channel.New(1, 1280, linkAddr1) @@ -1424,8 +1424,8 @@ func TestRouterDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -1626,8 +1626,8 @@ func TestPrefixDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) @@ -1893,8 +1893,8 @@ func TestAutoGenAddr(t *testing.T) { })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -4771,8 +4771,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { // or routers, or auto-generated address. for _, forwarding := range [...]bool{true, false} { t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { case e := <-ndpDisp.routerC: @@ -5353,8 +5353,8 @@ func TestRouterSolicitation(t *testing.T) { name: "Handle RAs always", handleRAs: ipv6.HandlingRAsAlwaysEnabled, afterFirstRS: func(t *testing.T, s *stack.Stack) { - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } }, }, @@ -5481,11 +5481,17 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) + } }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } }, }, diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index fc3c54e34..e2e073091 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -16,9 +16,10 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + tcpipbuffer "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -39,7 +40,7 @@ type PacketBufferOptions struct { // Data is the initial unparsed data for the new packet. If set, it will be // owned by the new packet. - Data buffer.VectorisedView + Data tcpipbuffer.VectorisedView // IsForwardedPacket identifies that the PacketBuffer being created is for a // forwarded packet. @@ -56,6 +57,34 @@ type PacketBufferOptions struct { // empty. Use of PacketBuffer in any other order is unsupported. // // PacketBuffer must be created with NewPacketBuffer. +// +// Internal structure: A PacketBuffer holds a pointer to buffer.Buffer, which +// exposes a logically-contiguous byte storage. The underlying storage structure +// is abstracted out, and should not be a concern here for most of the time. +// +// |- reserved ->| +// |--->| consumed (incoming) +// 0 V V +// +--------+----+----+--------------------+ +// | | | | current data ... | (buf) +// +--------+----+----+--------------------+ +// ^ | +// |<---| pushed (outgoing) +// +// When a PacketBuffer is created, a `reserved` header region can be specified, +// which stack pushes headers in this region for an outgoing packet. There could +// be no such region for an incoming packet, and `reserved` is 0. The value of +// `reserved` never changes in the entire lifetime of the packet. +// +// Outgoing Packet: When a header is pushed, `pushed` gets incremented by the +// pushed length, and the current value is stored for each header. PacketBuffer +// substracts this value from `reserved` to compute the starting offset of each +// header in `buf`. +// +// Incoming Packet: When a header is consumed (a.k.a. parsed), the current +// `consumed` value is stored for each header, and it gets incremented by the +// consumed length. PacketBuffer adds this value to `reserved` to compute the +// starting offset of each header in `buf`. type PacketBuffer struct { _ sync.NoCopy @@ -63,28 +92,16 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // data holds the payload of the packet. - // - // For inbound packets, Data is initially the whole packet. Then gets moved to - // headers via PacketHeader.Consume, when the packet is being parsed. - // - // For outbound packets, Data is the innermost layer, defined by the protocol. - // Headers are pushed in front of it via PacketHeader.Push. - // - // The bytes backing Data are immutable, a.k.a. users shouldn't write to its - // backing storage. - data buffer.VectorisedView + // buf is the underlying buffer for the packet. See struct level docs for + // details. + buf *buffer.Buffer + reserved int + pushed int + consumed int // headers stores metadata about each header. headers [numHeaderType]headerInfo - // header is the internal storage for outbound packets. Headers will be pushed - // (prepended) on this storage as the packet is being constructed. - // - // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and - // data are held in the same underlying buffer storage. - header buffer.Prependable - // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -131,10 +148,14 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - data: opts.Data, + buf: &buffer.Buffer{}, } if opts.ReserveHeaderBytes != 0 { - pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + pk.buf.AppendOwned(make([]byte, opts.ReserveHeaderBytes)) + pk.reserved = opts.ReserveHeaderBytes + } + for _, v := range opts.Data.Views() { + pk.buf.AppendOwned(v) } if opts.IsForwardedPacket { pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket @@ -145,13 +166,13 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { // ReservedHeaderBytes returns the number of bytes initially reserved for // headers. func (pk *PacketBuffer) ReservedHeaderBytes() int { - return pk.header.UsedLength() + pk.header.AvailableLength() + return pk.reserved } // AvailableHeaderBytes returns the number of bytes currently available for // headers. This is relevant to PacketHeader.Push method only. func (pk *PacketBuffer) AvailableHeaderBytes() int { - return pk.header.AvailableLength() + return pk.reserved - pk.pushed } // LinkHeader returns the handle to link-layer header. @@ -180,24 +201,18 @@ func (pk *PacketBuffer) TransportHeader() PacketHeader { // HeaderSize returns the total size of all headers in bytes. func (pk *PacketBuffer) HeaderSize() int { - // Note for inbound packets (Consume called), headers are not stored in - // pk.header. Thus, calculation of size of each header is needed. - var size int - for i := range pk.headers { - size += len(pk.headers[i].buf) - } - return size + return pk.pushed + pk.consumed } // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.data.Size() + return int(pk.buf.Size()) - pk.headerOffset() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize + return int(pk.buf.Size()) + packetBufferStructSize } // Data returns the handle to data portion of pk. @@ -206,61 +221,65 @@ func (pk *PacketBuffer) Data() PacketData { } // Views returns the underlying storage of the whole packet. -func (pk *PacketBuffer) Views() []buffer.View { - // Optimization for outbound packets that headers are in pk.header. - useHeader := true - for i := range pk.headers { - if !canUseHeader(&pk.headers[i]) { - useHeader = false - break - } - } +func (pk *PacketBuffer) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := pk.headerOffset() + pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views +} - dataViews := pk.data.Views() - - var vs []buffer.View - if useHeader { - vs = make([]buffer.View, 0, 1+len(dataViews)) - vs = append(vs, pk.header.View()) - } else { - vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) - for i := range pk.headers { - if v := pk.headers[i].buf; len(v) > 0 { - vs = append(vs, v) - } - } - } - return append(vs, dataViews...) +func (pk *PacketBuffer) headerOffset() int { + return pk.reserved - pk.pushed +} + +func (pk *PacketBuffer) headerOffsetOf(typ headerType) int { + return pk.reserved + pk.headers[typ].offset } -func canUseHeader(h *headerInfo) bool { - // h.offset will be negative if the header was pushed in to prependable - // portion, or doesn't matter when it's empty. - return len(h.buf) == 0 || h.offset < 0 +func (pk *PacketBuffer) dataOffset() int { + return pk.reserved + pk.consumed } -func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { +func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] - if h.buf != nil { + if h.length > 0 { panic(fmt.Sprintf("push must not be called twice: type %s", typ)) } - h.buf = buffer.View(pk.header.Prepend(size)) - h.offset = -pk.header.UsedLength() - return h.buf + if pk.pushed+size > pk.reserved { + panic("not enough headroom reserved") + } + pk.pushed += size + h.offset = -pk.pushed + h.length = size + return pk.headerView(typ) } -func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { +func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, consumed bool) { h := &pk.headers[typ] - if h.buf != nil { + if h.length > 0 { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.data.PullUp(size) + if pk.headerOffset()+pk.consumed+size > int(pk.buf.Size()) { + return nil, false + } + h.offset = pk.consumed + h.length = size + pk.consumed += size + return pk.headerView(typ), true +} + +func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { + h := &pk.headers[typ] + if h.length == 0 { + return nil + } + v, ok := pk.buf.PullUp(pk.headerOffsetOf(typ), h.length) if !ok { - return + panic("PullUp failed") } - pk.data.TrimFront(size) - h.buf = v - return h.buf, true + return v } // Clone makes a shallow copy of pk. @@ -270,9 +289,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - data: pk.data.Clone(nil), + buf: pk.buf, + reserved: pk.reserved, + pushed: pk.pushed, + consumed: pk.consumed, headers: pk.headers, - header: pk.header, Hash: pk.Hash, Owner: pk.Owner, GSOOptions: pk.GSOOptions, @@ -306,9 +327,11 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - newPk := NewPacketBuffer(PacketBufferOptions{ - Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), - }) + newPk := &PacketBuffer{ + buf: pk.buf, + // Treat unfilled header portion as reserved. + reserved: pk.AvailableHeaderBytes(), + } // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to // maintain this flag in the packet. Currently conntrack needs this flag to // tell if a noop connection should be inserted at Input hook. Once conntrack @@ -322,15 +345,12 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // headerInfo stores metadata about a header in a packet. type headerInfo struct { - // buf is the memorized slice for both prepended and consumed header. - // When header is prepended, buf serves as memorized value, which is a slice - // of pk.header. When header is consumed, buf is the slice pulled out from - // pk.Data, which is the only place to hold this header. - buf buffer.View - - // offset will be a negative number denoting the offset where this header is - // from the end of pk.header, if it is prepended. Otherwise, zero. + // offset is the offset of the header in pk.buf relative to + // pk.buf[pk.reserved]. See the PacketBuffer struct for details. offset int + + // length is the length of this header. + length int } // PacketHeader is a handle object to a header in the underlying packet. @@ -340,14 +360,14 @@ type PacketHeader struct { } // View returns the underlying storage of h. -func (h PacketHeader) View() buffer.View { - return h.pk.headers[h.typ].buf +func (h PacketHeader) View() tcpipbuffer.View { + return h.pk.headerView(h.typ) } // Push pushes size bytes in the front of its residing packet, and returns the // backing storage. Callers may only call one of Push or Consume once on each // header in the lifetime of the underlying packet. -func (h PacketHeader) Push(size int) buffer.View { +func (h PacketHeader) Push(size int) tcpipbuffer.View { return h.pk.push(h.typ, size) } @@ -356,7 +376,7 @@ func (h PacketHeader) Push(size int) buffer.View { // size, consumed will be false, and the state of h will not be affected. // Callers may only call one of Push or Consume once on each header in the // lifetime of the underlying packet. -func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { +func (h PacketHeader) Consume(size int) (v tcpipbuffer.View, consumed bool) { return h.pk.consume(h.typ, size) } @@ -367,55 +387,84 @@ type PacketData struct { // PullUp returns a contiguous view of size bytes from the beginning of d. // Callers should not write to or keep the view for later use. -func (d PacketData) PullUp(size int) (buffer.View, bool) { - return d.pk.data.PullUp(size) +func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { + return d.pk.buf.PullUp(d.pk.dataOffset(), size) } // DeleteFront removes count from the beginning of d. It panics if count > // d.Size(). All backing storage references after the front of the d are // invalidated. func (d PacketData) DeleteFront(count int) { - d.pk.data.TrimFront(count) + if !d.pk.buf.Remove(d.pk.dataOffset(), count) { + panic("count > d.Size()") + } } // CapLength reduces d to at most length bytes. func (d PacketData) CapLength(length int) { - d.pk.data.CapLength(length) + if length < 0 { + panic("length < 0") + } + if currLength := d.Size(); currLength > length { + trim := currLength - length + d.pk.buf.Remove(int(d.pk.buf.Size())-trim, trim) + } } // Views returns the underlying storage of d in a slice of Views. Caller should // not modify the returned slice. -func (d PacketData) Views() []buffer.View { - return d.pk.data.Views() +func (d PacketData) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := d.pk.dataOffset() + d.pk.buf.SubApply(offset, int(d.pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views } // AppendView appends v into d, taking the ownership of v. -func (d PacketData) AppendView(v buffer.View) { - d.pk.data.AppendView(v) +func (d PacketData) AppendView(v tcpipbuffer.View) { + d.pk.buf.AppendOwned(v) } -// ReadFromData moves at most count bytes from the beginning of srcData to the -// end of d and returns the number of bytes moved. -func (d PacketData) ReadFromData(srcData PacketData, count int) int { - return srcData.pk.data.ReadToVV(&d.pk.data, count) +// MergeFragment appends the data portion of frag to dst. It takes ownership of +// frag and frag should not be used again. +func MergeFragment(dst, frag *PacketBuffer) { + frag.buf.TrimFront(int64(frag.dataOffset())) + dst.buf.Merge(frag.buf) } // ReadFromVV moves at most count bytes from the beginning of srcVV to the end // of d and returns the number of bytes moved. -func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { - return srcVV.ReadToVV(&d.pk.data, count) +func (d PacketData) ReadFromVV(srcVV *tcpipbuffer.VectorisedView, count int) int { + done := 0 + for _, v := range srcVV.Views() { + if len(v) < count { + count -= len(v) + done += len(v) + d.pk.buf.AppendOwned(v) + } else { + v = v[:count] + count -= len(v) + done += len(v) + d.pk.buf.Append(v) + break + } + } + srcVV.TrimFront(done) + return done } // Size returns the number of bytes in the data payload of the packet. func (d PacketData) Size() int { - return d.pk.data.Size() + return int(d.pk.buf.Size()) - d.pk.dataOffset() } // AsRange returns a Range representing the current data payload of the packet. func (d PacketData) AsRange() Range { return Range{ pk: d.pk, - offset: d.pk.HeaderSize(), + offset: d.pk.dataOffset(), length: d.Size(), } } @@ -425,17 +474,12 @@ func (d PacketData) AsRange() Range { // // This method exists for compatibility between PacketBuffer and VectorisedView. // It may be removed later and should be used with care. -func (d PacketData) ExtractVV() buffer.VectorisedView { - return d.pk.data -} - -// Replace replaces the data portion of the packet with vv, taking the ownership -// of vv. -// -// This method exists for compatibility between PacketBuffer and VectorisedView. -// It may be removed later and should be used with care. -func (d PacketData) Replace(vv buffer.VectorisedView) { - d.pk.data = vv +func (d PacketData) ExtractVV() tcpipbuffer.VectorisedView { + var vv tcpipbuffer.VectorisedView + d.pk.buf.SubApply(d.pk.dataOffset(), d.pk.Size(), func(v []byte) { + vv.AppendView(v) + }) + return vv } // Range represents a contiguous subportion of a PacketBuffer. @@ -479,9 +523,9 @@ func (r Range) Capped(max int) Range { // AsView returns the backing storage of r if possible. It will allocate a new // View if r spans multiple pieces internally. Caller should not write to the // returned View in any way. -func (r Range) AsView() buffer.View { +func (r Range) AsView() tcpipbuffer.View { var allocated bool - var v buffer.View + var v tcpipbuffer.View r.iterate(func(b []byte) { if v == nil { // v has not been assigned, allowing first view to be returned. @@ -502,7 +546,7 @@ func (r Range) AsView() buffer.View { } // ToOwnedView returns a owned copy of data in r. -func (r Range) ToOwnedView() buffer.View { +func (r Range) ToOwnedView() tcpipbuffer.View { if r.length == 0 { return nil } @@ -523,63 +567,7 @@ func (r Range) Checksum() uint16 { // iterate calls fn for each piece in r. fn is always called with a non-empty // slice. func (r Range) iterate(fn func([]byte)) { - w := window{ - offset: r.offset, - length: r.length, - } - // Header portion. - for i := range r.pk.headers { - if b := w.process(r.pk.headers[i].buf); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - // Data portion. - if !w.isDone() { - for _, v := range r.pk.data.Views() { - if b := w.process(v); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - } -} - -// window represents contiguous region of byte stream. User would call process() -// to input bytes, and obtain a subslice that is inside the window. -type window struct { - offset int - length int -} - -// isDone returns true if the window has passed and further process() calls will -// always return an empty slice. This can be used to end processing early. -func (w *window) isDone() bool { - return w.length == 0 -} - -// process feeds b in and returns a subslice that is inside the window. The -// returned slice will be a subslice of b, and it does not keep b after method -// returns. This method may return an empty slice if nothing in b is inside the -// window. -func (w *window) process(b []byte) (inWindow []byte) { - if w.offset >= len(b) { - w.offset -= len(b) - return nil - } - if w.offset > 0 { - b = b[w.offset:] - w.offset = 0 - } - if w.length < len(b) { - b = b[:w.length] - } - w.length -= len(b) - return b + r.pk.buf.SubApply(r.offset, r.length, fn) } // PayloadSince returns packet payload starting from and including a particular @@ -587,21 +575,14 @@ func (w *window) process(b []byte) (inWindow []byte) { // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. -func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.data.Size() - for _, hinfo := range h.pk.headers[h.typ:] { - size += len(hinfo.buf) +func PayloadSince(h PacketHeader) tcpipbuffer.View { + offset := h.pk.headerOffset() + for i := headerType(0); i < h.typ; i++ { + offset += h.pk.headers[i].length } - - v := make(buffer.View, 0, size) - - for _, hinfo := range h.pk.headers[h.typ:] { - v = append(v, hinfo.buf...) - } - - for _, view := range h.pk.data.Views() { - v = append(v, view...) - } - - return v + return Range{ + pk: h.pk, + offset: offset, + length: int(h.pk.buf.Size()) - offset, + }.ToOwnedView() } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index bd4eb4fed..1c1aeb950 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -444,23 +444,8 @@ func TestPacketBufferData(t *testing.T) { checkData(t, pkt, []byte(tc.data+s)) }) - // ReadFromData/VV + // ReadFromVV for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { - t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { - s := "TO READ" - otherPkt := NewPacketBuffer(PacketBufferOptions{ - Data: vv(s, s), - }) - s += s - - pkt := tc.makePkt(t) - pkt.Data().ReadFromData(otherPkt.Data(), n) - - if n < len(s) { - s = s[:n] - } - checkData(t, pkt, []byte(tc.data+s)) - }) t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { s := "TO READ" srcVV := vv(s, s) @@ -487,16 +472,6 @@ func TestPacketBufferData(t *testing.T) { t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) } }) - - // Replace - t.Run("Replace", func(t *testing.T) { - s := "REPLACED" - - pkt := tc.makePkt(t) - pkt.Data().Replace(vv(s)) - - checkData(t, pkt, []byte(s)) - }) }) } } @@ -568,7 +543,7 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { t.Helper() if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { - t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) } if got := pkt.Data().Size(); got != len(want) { t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3d9e1e286..483a960c8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -492,9 +492,9 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs for the -// passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { return &tcpip.ErrUnknownProtocol{} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2c40cc43..ff88b1bd3 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4220,8 +4220,8 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) } - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) @@ -4275,8 +4275,8 @@ func TestFindRouteWithForwarding(t *testing.T) { // Disabling forwarding when the route is dependent on forwarding being // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) } { err := send(r, data) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index dbd279c94..42bc53328 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -475,11 +475,11 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) } - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3df1bbd68..87d36e1dd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -714,11 +714,11 @@ func TestExternalLoopbackTraffic(t *testing.T) { } if test.forwarding { - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 8fd9be32b..c8b9c9b5c 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -224,11 +224,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) } - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err) } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { |