From 47bc115158397024841aa3747be7558b2c317cbb Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 21 Apr 2021 18:07:13 -0700 Subject: Only carry GSO options in the packet buffer With this change, GSO options no longer needs to be passed around as a function argument in the write path. This change is done in preparation for a later change that defers segmentation, and may change GSO options for a packet as it flows down the stack. Updates #170. PiperOrigin-RevId: 369774872 --- pkg/tcpip/transport/icmp/endpoint.go | 4 ++-- pkg/tcpip/transport/raw/endpoint.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 20 +++++++++++--------- pkg/tcpip/transport/tcp/endpoint.go | 20 +++++++++----------- pkg/tcpip/transport/tcp/protocol.go | 6 +++--- pkg/tcpip/transport/tcp/snd.go | 4 ++-- pkg/tcpip/transport/tcp/testing/context/context.go | 4 ++-- pkg/tcpip/transport/udp/endpoint.go | 2 +- 8 files changed, 31 insertions(+), 31 deletions(-) (limited to 'pkg/tcpip/transport') diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 33ed78f54..9948f305b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -431,7 +431,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() return err } @@ -478,7 +478,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 10453a42a..bcec3d2e7 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -354,7 +354,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp Data: buffer.View(payloadBytes).ToVectorisedView(), }) pkt.Owner = owner - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := route.WritePacket(stack.NetworkHeaderParams{ Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 7bc6b08f0..524d5cabf 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -721,14 +721,14 @@ type tcpFields struct { func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. - if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { + if err := e.sendTCP(r, tf, buffer.VectorisedView{}, stack.GSO{}); err != nil { e.stats.SendErrors.SynSendToNetworkFailed.Increment() } putOptions(tf.opts) return nil } -func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO) tcpip.Error { tf.txHash = e.txHash if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() @@ -738,7 +738,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV return nil } -func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { +func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso stack.GSO) { optLen := len(tf.opts) tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) pkt.TransportProtocolNumber = header.TCPProtocolNumber @@ -755,7 +755,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. - if gso != nil && gso.NeedsCsum { + if gso.Type != stack.GSONone && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We // calculate a checksum of the pseudo-header and save it in the // TCP header, then the kernel calculate a checksum of the @@ -767,7 +767,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error { // We need to shallow clone the VectorisedView here as ReadToView will // split the VectorisedView and Trim underlying views as it splits. Not // doing the clone here will cause the underlying views of data itself @@ -799,13 +799,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Data().ReadFromVV(&data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkt.GSOOptions = gso pkts.PushBack(pkt) } if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } - sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) + sent, err := r.WritePackets(pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) if err != nil { r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) } @@ -815,13 +816,13 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > math.MaxUint16 { tf.rcvWnd = math.MaxUint16 } - if r.Loop()&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + if r.Loop()&stack.PacketLoop == 0 && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { return sendTCPBatch(r, tf, data, gso, owner) } @@ -829,6 +830,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, Data: data, }) + pkt.GSOOptions = gso pkt.Hash = tf.txHash pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) @@ -836,7 +838,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f25dc781a..50f72bf38 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -599,7 +599,7 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - gso *stack.GSO + gso stack.GSO // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` @@ -2943,28 +2943,26 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { } func (e *endpoint) initHardwareGSO() { - gso := &stack.GSO{} switch e.route.NetProto() { case header.IPv4ProtocolNumber: - gso.Type = stack.GSOTCPv4 - gso.L3HdrLen = header.IPv4MinimumSize + e.gso.Type = stack.GSOTCPv4 + e.gso.L3HdrLen = header.IPv4MinimumSize case header.IPv6ProtocolNumber: - gso.Type = stack.GSOTCPv6 - gso.L3HdrLen = header.IPv6MinimumSize + e.gso.Type = stack.GSOTCPv6 + e.gso.L3HdrLen = header.IPv6MinimumSize default: panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } - gso.NeedsCsum = true - gso.CsumOffset = header.TCPChecksumOffset - gso.MaxSize = e.route.GSOMaxSize() - e.gso = gso + e.gso.NeedsCsum = true + e.gso.CsumOffset = header.TCPChecksumOffset + e.gso.MaxSize = e.route.GSOMaxSize() } func (e *endpoint) initGSO() { if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() } else if e.route.HasSoftwareGSOCapability() { - e.gso = &stack.GSO{ + e.gso = stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, NeedsCsum: false, diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index fe0d7f10f..a3d1aa1a3 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -159,8 +159,8 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. -func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { - route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) +func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { + route, err := st.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err } @@ -200,7 +200,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error seq: seq, ack: ack, rcvWnd: 0, - }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) + }, buffer.VectorisedView{}, stack.GSO{}, nil /* PacketOwner */) } // SetOption implements stack.TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index cf2e8dcd8..2b32cb7b2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -180,7 +180,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint }, RTO: 1 * time.Second, }, - gso: ep.gso != nil, + gso: ep.gso.Type != stack.GSONone, } if s.gso { @@ -830,7 +830,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // If GSO is not in use then cap available to // maxPayloadSize. When GSO is in use the gVisor GSO logic or // the host GSO logic will cap the segment to the correct size. - if s.ep.gso == nil && available > s.MaxPayloadSize { + if s.ep.gso.Type == stack.GSONone && available > s.MaxPayloadSize { available = s.MaxPayloadSize } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 7578d64ec..16f8c5212 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -331,8 +331,8 @@ func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) + if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize) } checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c9f2f3efc..f7dd50d35 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -848,7 +848,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: ProtocolNumber, TTL: ttl, TOS: tos, -- cgit v1.2.3