summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go84
-rw-r--r--pkg/tcpip/transport/tcp/connect.go16
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go7
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go13
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go12
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go31
-rw-r--r--pkg/tcpip/transport/tcp/segment.go47
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go79
9 files changed, 183 insertions, 110 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6b3238d6b..47982ca41 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
- netProto = s.route.NetProto
+ netProto = s.netProto
}
+
+ route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return nil, err
+ }
+ route.ResolveWith(s.remoteLinkAddr)
+
n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6Only
n.ID = s.id
- n.boundNICID = s.route.NICID()
- n.route = s.route.Clone()
- n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.boundNICID = s.nicID
+ n.route = route
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto}
n.rcvBufSize = int(l.rcvWnd)
n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
@@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- return n
+ return n, nil
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ if err != nil {
+ return nil, err
+ }
// Lock the endpoint before registering to ensure that no out of
// band changes are possible due to incoming packets etc till
@@ -467,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
-func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
e.rcvListMu.Unlock()
@@ -477,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
switch {
@@ -492,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
- return
+ return nil
}
ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
} else {
// If cookies are in use but the endpoint accept queue
// is full then drop the syn.
@@ -506,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Send SYN without window scaling because we currently
// don't encode this information in the cookie.
//
@@ -523,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TS: opts.TS,
TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: calculateAdvertisedMSS(e.userMSS, s.route),
+ MSS: calculateAdvertisedMSS(e.userMSS, route),
}
- e.sendSynTCP(&s.route, tcpFields{
+ fields := tcpFields{
id: s.id,
ttl: e.ttl,
tos: e.sendTOS,
@@ -533,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
seq: cookie,
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
- }, synOpts)
+ }
+ if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ return err
+ }
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ return nil
}
case (s.flags & header.TCPFlagAck) != 0:
@@ -547,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
if !ctx.synRcvdCount.synCookiesInUse() {
@@ -566,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
iss := s.ackNumber - 1
@@ -587,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
@@ -608,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ if err != nil {
+ return err
+ }
n.mu.Lock()
@@ -622,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return nil
}
// Register new endpoint so that packets are routed to it.
@@ -632,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return err
}
n.isRegistered = true
@@ -670,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
+ return nil
+
+ default:
+ return nil
}
}
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
-func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.v6only
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
@@ -714,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
- return nil
+ return
}
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
s := e.segmentQueue.dequeue()
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
close(e.drainDone)
@@ -738,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
break
}
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index c890e2326..2facbebec 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
MSS: amss,
}
if ttl == 0 {
- ttl = s.route.DefaultTTL()
+ ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -771,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
// TCP header, then the kernel calculate a checksum of the
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
- } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ } else if r.RequiresTXTransportChecksum() {
xsum = header.ChecksumVV(pkt.Data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
@@ -1044,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() {
// only when the endpoint is in StateClose and we want to deliver the segment
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
- ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
- ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
if ep == nil {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */)
s.decRef()
return
}
@@ -1626,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
}
for _, netProto := range netProtos {
- if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil {
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 98aecab9e..21162f01a 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -172,10 +172,11 @@ func (d *dispatcher) wait() {
d.wg.Wait()
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
- s := newSegment(r, id, pkt)
- if !s.parse() {
+
+ s := newIncomingSegment(id, pkt)
+ if !s.parse(pkt.RXTransportChecksumValidated) {
ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b817ab6ef..258f9f1bb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
// Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
+ s := newOutgoingSegment(e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
@@ -2316,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// done yet) or the reservation was freed between the check above and
// the FindTransportEndpoint below. But rather than retry the same port
// we just skip it and move on.
- transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID())
if transEP == nil {
// ReservePort failed but there is no registered endpoint with
// demuxer. Which indicates there is at least some endpoint that has
@@ -2385,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.ID
- s.route = r.Clone()
e.sndWaker.Assert()
}
}
@@ -2451,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Queue fin segment.
- s := newSegmentFromView(&e.route, e.ID, nil)
+ s := newOutgoingSegment(e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
// Mark endpoint as closed.
@@ -2723,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
}
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -3082,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() {
}
func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.route.HasHardwareGSOCapability() {
e.initHardwareGSO()
- } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ } else if e.route.HasSoftwareGSOCapability() {
e.gso = &stack.GSO{
MaxSize: e.route.GSOMaxSize(),
Type: stack.GSOSW,
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b25431467..2bcc5e1c2 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() {
switch {
case epState == StateInitial || epState == StateBound:
case epState.connected() || epState.handshake():
- if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
- if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ if !e.route.HasSaveRestoreCapability() {
+ if !e.route.HasDisconncetOkCapability() {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
}
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 070b634b4..0664789da 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -30,6 +30,8 @@ import (
// The canonical way of using it is to pass the Forwarder.HandlePacket function
// to stack.SetTransportProtocolHandler.
type Forwarder struct {
+ stack *stack.Stack
+
maxInFlight int
handler func(*ForwarderRequest)
@@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
rcvWnd = DefaultReceiveBufferSize
}
return &Forwarder{
+ stack: s,
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
@@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
- s := newSegment(r, id, pkt)
+func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
- if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn {
return false
}
@@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
delete(r.forwarder.inFlight, r.segment.id)
r.forwarder.mu.Unlock()
- // If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
+ replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */)
}
// Release all resources.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 5bce73605..2329aca4b 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- p.dispatcher.queuePacket(r, ep, id, pkt)
+func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ p.dispatcher.queuePacket(ep, id, pkt)
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
@@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
- s := newSegment(r, id, pkt)
+func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
- if !s.parse() || !s.csumValid {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid {
return stack.UnknownDestinationPacketMalformed
}
if !s.flagIsSet(header.TCPFlagRst) {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(p.stack, s, stack.DefaultTOS, 0)
}
return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment, tos, ttl uint8) {
+//
+// If the passed TTL is 0, then the route's default TTL will be used.
+func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error {
+ route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
@@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) {
flags |= header.TCPFlagAck
ack = s.sequenceNumber.Add(s.logicalLen())
}
- sendTCP(&s.route, tcpFields{
+
+ if ttl == 0 {
+ ttl = route.DefaultTTL()
+ }
+
+ return sendTCP(&route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 1f9c5cf50..2091989cc 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -45,9 +46,18 @@ type segment struct {
ep *endpoint
qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
- route stack.Route `state:"manual"`
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- hdr header.TCP
+
+ // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of
+ // individual members for link/network packet info.
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ netProto tcpip.NetworkProtocolNumber
+ nicID tcpip.NICID
+ remoteLinkAddr tcpip.LinkAddress
+
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -76,11 +86,16 @@ type segment struct {
acked bool
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+ netHdr := pkt.Network()
s := &segment{
- refCnt: 1,
- id: id,
- route: r.Clone(),
+ refCnt: 1,
+ id: id,
+ srcAddr: netHdr.SourceAddress(),
+ dstAddr: netHdr.DestinationAddress(),
+ netProto: pkt.NetworkProtocolNumber,
+ nicID: pkt.NICID,
+ remoteLinkAddr: pkt.SourceLinkAddress(),
}
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
@@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB
return s
}
-func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
- route: r.Clone(),
}
s.rcvdTime = time.Now()
if len(v) != 0 {
@@ -110,7 +124,9 @@ func (s *segment) clone() *segment {
ackNumber: s.ackNumber,
flags: s.flags,
window: s.window,
- route: s.route.Clone(),
+ netProto: s.netProto,
+ nicID: s.nicID,
+ remoteLinkAddr: s.remoteLinkAddr,
viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
@@ -160,7 +176,6 @@ func (s *segment) decRef() {
panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
}
}
- s.route.Release()
}
}
@@ -198,10 +213,10 @@ func (s *segment) segMemSize() int {
//
// Returns boolean indicating if the parsing was successful.
//
-// If checksum verification is not offloaded then parse also verifies the
+// If checksum verification may not be skipped, parse also verifies the
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
-func (s *segment) parse() bool {
+func (s *segment) parse(skipChecksumValidation bool) bool {
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -220,16 +235,14 @@ func (s *segment) parse() bool {
s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
- // Query the link capabilities to decide if checksum validation is
- // required.
verifyChecksum := true
- if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ if skipChecksumValidation {
s.csumValid = true
verifyChecksum = false
}
if verifyChecksum {
s.csum = s.hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr)))
xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index ee1ac778d..8782316f2 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -554,6 +554,11 @@ func (s *segment) StateFields() []string {
"refCnt",
"ep",
"qFlags",
+ "srcAddr",
+ "dstAddr",
+ "netProto",
+ "nicID",
+ "remoteLinkAddr",
"data",
"hdr",
"viewToDeliver",
@@ -578,29 +583,34 @@ func (s *segment) beforeSave() {}
func (s *segment) StateSave(stateSinkObject state.Sink) {
s.beforeSave()
var dataValue buffer.VectorisedView = s.saveData()
- stateSinkObject.SaveValue(4, dataValue)
+ stateSinkObject.SaveValue(9, dataValue)
var optionsValue []byte = s.saveOptions()
- stateSinkObject.SaveValue(14, optionsValue)
+ stateSinkObject.SaveValue(19, optionsValue)
var rcvdTimeValue unixTime = s.saveRcvdTime()
- stateSinkObject.SaveValue(16, rcvdTimeValue)
+ stateSinkObject.SaveValue(21, rcvdTimeValue)
var xmitTimeValue unixTime = s.saveXmitTime()
- stateSinkObject.SaveValue(17, xmitTimeValue)
+ stateSinkObject.SaveValue(22, xmitTimeValue)
stateSinkObject.Save(0, &s.segmentEntry)
stateSinkObject.Save(1, &s.refCnt)
stateSinkObject.Save(2, &s.ep)
stateSinkObject.Save(3, &s.qFlags)
- stateSinkObject.Save(5, &s.hdr)
- stateSinkObject.Save(6, &s.viewToDeliver)
- stateSinkObject.Save(7, &s.sequenceNumber)
- stateSinkObject.Save(8, &s.ackNumber)
- stateSinkObject.Save(9, &s.flags)
- stateSinkObject.Save(10, &s.window)
- stateSinkObject.Save(11, &s.csum)
- stateSinkObject.Save(12, &s.csumValid)
- stateSinkObject.Save(13, &s.parsedOptions)
- stateSinkObject.Save(15, &s.hasNewSACKInfo)
- stateSinkObject.Save(18, &s.xmitCount)
- stateSinkObject.Save(19, &s.acked)
+ stateSinkObject.Save(4, &s.srcAddr)
+ stateSinkObject.Save(5, &s.dstAddr)
+ stateSinkObject.Save(6, &s.netProto)
+ stateSinkObject.Save(7, &s.nicID)
+ stateSinkObject.Save(8, &s.remoteLinkAddr)
+ stateSinkObject.Save(10, &s.hdr)
+ stateSinkObject.Save(11, &s.viewToDeliver)
+ stateSinkObject.Save(12, &s.sequenceNumber)
+ stateSinkObject.Save(13, &s.ackNumber)
+ stateSinkObject.Save(14, &s.flags)
+ stateSinkObject.Save(15, &s.window)
+ stateSinkObject.Save(16, &s.csum)
+ stateSinkObject.Save(17, &s.csumValid)
+ stateSinkObject.Save(18, &s.parsedOptions)
+ stateSinkObject.Save(20, &s.hasNewSACKInfo)
+ stateSinkObject.Save(23, &s.xmitCount)
+ stateSinkObject.Save(24, &s.acked)
}
func (s *segment) afterLoad() {}
@@ -610,22 +620,27 @@ func (s *segment) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(1, &s.refCnt)
stateSourceObject.Load(2, &s.ep)
stateSourceObject.Load(3, &s.qFlags)
- stateSourceObject.Load(5, &s.hdr)
- stateSourceObject.Load(6, &s.viewToDeliver)
- stateSourceObject.Load(7, &s.sequenceNumber)
- stateSourceObject.Load(8, &s.ackNumber)
- stateSourceObject.Load(9, &s.flags)
- stateSourceObject.Load(10, &s.window)
- stateSourceObject.Load(11, &s.csum)
- stateSourceObject.Load(12, &s.csumValid)
- stateSourceObject.Load(13, &s.parsedOptions)
- stateSourceObject.Load(15, &s.hasNewSACKInfo)
- stateSourceObject.Load(18, &s.xmitCount)
- stateSourceObject.Load(19, &s.acked)
- stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) })
- stateSourceObject.LoadValue(14, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
- stateSourceObject.LoadValue(16, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) })
- stateSourceObject.LoadValue(17, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) })
+ stateSourceObject.Load(4, &s.srcAddr)
+ stateSourceObject.Load(5, &s.dstAddr)
+ stateSourceObject.Load(6, &s.netProto)
+ stateSourceObject.Load(7, &s.nicID)
+ stateSourceObject.Load(8, &s.remoteLinkAddr)
+ stateSourceObject.Load(10, &s.hdr)
+ stateSourceObject.Load(11, &s.viewToDeliver)
+ stateSourceObject.Load(12, &s.sequenceNumber)
+ stateSourceObject.Load(13, &s.ackNumber)
+ stateSourceObject.Load(14, &s.flags)
+ stateSourceObject.Load(15, &s.window)
+ stateSourceObject.Load(16, &s.csum)
+ stateSourceObject.Load(17, &s.csumValid)
+ stateSourceObject.Load(18, &s.parsedOptions)
+ stateSourceObject.Load(20, &s.hasNewSACKInfo)
+ stateSourceObject.Load(23, &s.xmitCount)
+ stateSourceObject.Load(24, &s.acked)
+ stateSourceObject.LoadValue(9, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(19, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
+ stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(22, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) })
}
func (q *segmentQueue) StateTypeName() string {