summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go147
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go7
2 files changed, 69 insertions, 85 deletions
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 047704c80..43b76bee5 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -66,89 +67,68 @@ func (q *epQueue) empty() bool {
// processor is responsible for processing packets queued to a tcp endpoint.
type processor struct {
epQ epQueue
+ sleeper sleep.Sleeper
newEndpointWaker sleep.Waker
closeWaker sleep.Waker
- id int
- wg sync.WaitGroup
-}
-
-func newProcessor(id int) *processor {
- p := &processor{
- id: id,
- }
- p.wg.Add(1)
- go p.handleSegments()
- return p
}
func (p *processor) close() {
p.closeWaker.Assert()
}
-func (p *processor) wait() {
- p.wg.Wait()
-}
-
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
p.newEndpointWaker.Assert()
}
-func (p *processor) handleSegments() {
- const newEndpointWaker = 1
- const closeWaker = 2
- s := sleep.Sleeper{}
- s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
- s.AddWaker(&p.closeWaker, closeWaker)
- defer s.Done()
+const (
+ newEndpointWaker = 1
+ closeWaker = 2
+)
+
+func (p *processor) start(wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer p.sleeper.Done()
+
for {
- id, ok := s.Fetch(true)
- if ok && id == closeWaker {
- p.wg.Done()
- return
+ if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ break
}
- for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ for {
+ ep := p.epQ.dequeue()
+ if ep == nil {
+ break
+ }
if ep.segmentQueue.empty() {
continue
}
- // If socket has transitioned out of connected state
- // then just let the worker handle the packet.
+ // If socket has transitioned out of connected state then just let the
+ // worker handle the packet.
//
- // NOTE: We read this outside of e.mu lock which means
- // that by the time we get to handleSegments the
- // endpoint may not be in ESTABLISHED. But this should
- // be fine as all normal shutdown states are handled by
- // handleSegments and if the endpoint moves to a
- // CLOSED/ERROR state then handleSegments is a noop.
- if ep.EndpointState() != StateEstablished {
- ep.newSegmentWaker.Assert()
- continue
- }
-
- if !ep.mu.TryLock() {
- ep.newSegmentWaker.Assert()
- continue
- }
- // If the endpoint is in a connected state then we do
- // direct delivery to ensure low latency and avoid
- // scheduler interactions.
- if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
- // Send any active resets if required.
- if err != nil {
+ // NOTE: We read this outside of e.mu lock which means that by the time
+ // we get to handleSegments the endpoint may not be in ESTABLISHED. But
+ // this should be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a CLOSED/ERROR state
+ // then handleSegments is a noop.
+ if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
+ // If the endpoint is in a connected state then we do direct delivery
+ // to ensure low latency and avoid scheduler interactions.
+ switch err := ep.handleSegments(true /* fastPath */); {
+ case err != nil:
+ // Send any active resets if required.
ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
}
- ep.notifyProtocolGoroutine(notifyTickleWorker)
ep.mu.Unlock()
- continue
- }
-
- if !ep.segmentQueue.empty() {
- p.epQ.enqueue(ep)
+ } else {
+ ep.newSegmentWaker.Assert()
}
-
- ep.mu.Unlock()
}
}
}
@@ -159,31 +139,36 @@ func (p *processor) handleSegments() {
// hash of the endpoint id to ensure that delivery for the same endpoint happens
// in-order.
type dispatcher struct {
- processors []*processor
+ processors []processor
seed uint32
-}
-
-func newDispatcher(nProcessors int) *dispatcher {
- processors := []*processor{}
- for i := 0; i < nProcessors; i++ {
- processors = append(processors, newProcessor(i))
- }
- return &dispatcher{
- processors: processors,
- seed: generateRandUint32(),
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
}
}
func (d *dispatcher) close() {
- for _, p := range d.processors {
- p.close()
+ for i := range d.processors {
+ d.processors[i].close()
}
}
func (d *dispatcher) wait() {
- for _, p := range d.processors {
- p.wait()
- }
+ d.wg.Wait()
}
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
@@ -231,20 +216,18 @@ func generateRandUint32() uint32 {
if _, err := rand.Read(b); err != nil {
panic(err)
}
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return binary.LittleEndian.Uint32(b)
}
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
- payload := []byte{
- byte(id.LocalPort),
- byte(id.LocalPort >> 8),
- byte(id.RemotePort),
- byte(id.RemotePort >> 8)}
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
h := jenkins.Sum32(d.seed)
- h.Write(payload)
+ h.Write(payload[:])
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
- return d.processors[h.Sum32()%uint32(len(d.processors))]
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index f2ae6ce50..b34e47bbd 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -174,7 +174,7 @@ type protocol struct {
maxRetries uint32
synRcvdCount synRcvdCounter
synRetries uint8
- dispatcher *dispatcher
+ dispatcher dispatcher
}
// Number returns the tcp protocol number.
@@ -515,7 +515,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
- return &protocol{
+ p := protocol{
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultSendBufferSize,
@@ -531,10 +531,11 @@ func NewProtocol() stack.TransportProtocol {
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
- dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
}
+ p.dispatcher.init(runtime.GOMAXPROCS(0))
+ return &p
}