diff options
Diffstat (limited to 'pkg')
26 files changed, 3139 insertions, 0 deletions
diff --git a/pkg/sentry/kernel/memevent/memevent_state_autogen.go b/pkg/sentry/kernel/memevent/memevent_state_autogen.go new file mode 100755 index 000000000..8bfbba80f --- /dev/null +++ b/pkg/sentry/kernel/memevent/memevent_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package memevent + diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go new file mode 100755 index 000000000..b0d98e7f0 --- /dev/null +++ b/pkg/sentry/kernel/memevent/memory_events.go @@ -0,0 +1,111 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memevent implements the memory usage events controller, which +// periodically emits events via the eventchannel. +package memevent + +import ( + "sync" + "time" + + "gvisor.dev/gvisor/pkg/eventchannel" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/sentry/kernel" + pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto" + "gvisor.dev/gvisor/pkg/sentry/usage" +) + +var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.") +var totalEvents = metric.MustCreateNewUint64Metric("/memory_events/events", false /*sync*/, "Total number of memory events emitted.") + +// MemoryEvents describes the configuration for the global memory event emitter. +type MemoryEvents struct { + k *kernel.Kernel + + // The period is how often to emit an event. The memory events goroutine + // will ensure a minimum of one event is emitted per this period, regardless + // how of much memory usage has changed. + period time.Duration + + // Writing to this channel indicates the memory goroutine should stop. + stop chan struct{} + + // done is used to signal when the memory event goroutine has exited. + done sync.WaitGroup +} + +// New creates a new MemoryEvents. +func New(k *kernel.Kernel, period time.Duration) *MemoryEvents { + return &MemoryEvents{ + k: k, + period: period, + stop: make(chan struct{}), + } +} + +// Stop stops the memory usage events emitter goroutine. Stop must not be called +// concurrently with Start and may only be called once. +func (m *MemoryEvents) Stop() { + close(m.stop) + m.done.Wait() +} + +// Start starts the memory usage events emitter goroutine. Start must not be +// called concurrently with Stop and may only be called once. +func (m *MemoryEvents) Start() { + if m.period == 0 { + return + } + m.done.Add(1) + go m.run() // S/R-SAFE: doesn't interact with saved state. +} + +func (m *MemoryEvents) run() { + defer m.done.Done() + + // Emit the first event immediately on startup. + totalTicks.Increment() + m.emit() + + ticker := time.NewTicker(m.period) + defer ticker.Stop() + + for { + select { + case <-m.stop: + return + case <-ticker.C: + totalTicks.Increment() + m.emit() + } + } +} + +func (m *MemoryEvents) emit() { + totalPlatform, err := m.k.MemoryFile().TotalUsage() + if err != nil { + log.Warningf("Failed to fetch memory usage for memory events: %v", err) + return + } + snapshot, _ := usage.MemoryAccounting.Copy() + total := totalPlatform + snapshot.Mapped + + totalEvents.Increment() + eventchannel.Emit(&pb.MemoryUsageEvent{ + Mapped: snapshot.Mapped, + Total: total, + }) +} diff --git a/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go b/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go new file mode 100755 index 000000000..f8b857fa9 --- /dev/null +++ b/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go @@ -0,0 +1,88 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/kernel/memevent/memory_events.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type MemoryUsageEvent struct { + Total uint64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"` + Mapped uint64 `protobuf:"varint,2,opt,name=mapped,proto3" json:"mapped,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MemoryUsageEvent) Reset() { *m = MemoryUsageEvent{} } +func (m *MemoryUsageEvent) String() string { return proto.CompactTextString(m) } +func (*MemoryUsageEvent) ProtoMessage() {} +func (*MemoryUsageEvent) Descriptor() ([]byte, []int) { + return fileDescriptor_cd85fc8d1130e4b0, []int{0} +} + +func (m *MemoryUsageEvent) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MemoryUsageEvent.Unmarshal(m, b) +} +func (m *MemoryUsageEvent) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MemoryUsageEvent.Marshal(b, m, deterministic) +} +func (m *MemoryUsageEvent) XXX_Merge(src proto.Message) { + xxx_messageInfo_MemoryUsageEvent.Merge(m, src) +} +func (m *MemoryUsageEvent) XXX_Size() int { + return xxx_messageInfo_MemoryUsageEvent.Size(m) +} +func (m *MemoryUsageEvent) XXX_DiscardUnknown() { + xxx_messageInfo_MemoryUsageEvent.DiscardUnknown(m) +} + +var xxx_messageInfo_MemoryUsageEvent proto.InternalMessageInfo + +func (m *MemoryUsageEvent) GetTotal() uint64 { + if m != nil { + return m.Total + } + return 0 +} + +func (m *MemoryUsageEvent) GetMapped() uint64 { + if m != nil { + return m.Mapped + } + return 0 +} + +func init() { + proto.RegisterType((*MemoryUsageEvent)(nil), "gvisor.MemoryUsageEvent") +} + +func init() { + proto.RegisterFile("pkg/sentry/kernel/memevent/memory_events.proto", fileDescriptor_cd85fc8d1130e4b0) +} + +var fileDescriptor_cd85fc8d1130e4b0 = []byte{ + // 128 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2b, 0xc8, 0x4e, 0xd7, + 0x2f, 0x4e, 0xcd, 0x2b, 0x29, 0xaa, 0xd4, 0xcf, 0x4e, 0x2d, 0xca, 0x4b, 0xcd, 0xd1, 0xcf, 0x4d, + 0xcd, 0x4d, 0x2d, 0x4b, 0xcd, 0x2b, 0x01, 0x31, 0xf2, 0x8b, 0x2a, 0xe3, 0xc1, 0x9c, 0x62, 0xbd, + 0x82, 0xa2, 0xfc, 0x92, 0x7c, 0x21, 0xb6, 0xf4, 0xb2, 0xcc, 0xe2, 0xfc, 0x22, 0x25, 0x07, 0x2e, + 0x01, 0x5f, 0xb0, 0x74, 0x68, 0x71, 0x62, 0x7a, 0xaa, 0x2b, 0x48, 0x89, 0x90, 0x08, 0x17, 0x6b, + 0x49, 0x7e, 0x49, 0x62, 0x8e, 0x04, 0xa3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x84, 0x23, 0x24, 0xc6, + 0xc5, 0x96, 0x9b, 0x58, 0x50, 0x90, 0x9a, 0x22, 0xc1, 0x04, 0x16, 0x86, 0xf2, 0x92, 0xd8, 0xc0, + 0x06, 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x99, 0x31, 0x2f, 0x9d, 0x82, 0x00, 0x00, 0x00, +} diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go new file mode 100755 index 000000000..cd6ce930a --- /dev/null +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -0,0 +1,722 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gonet provides a Go net package compatible wrapper for a tcpip stack. +package gonet + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +var ( + errCanceled = errors.New("operation canceled") + errWouldBlock = errors.New("operation would block") +) + +// timeoutError is how the net package reports timeouts. +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +// A Listener is a wrapper around a tcpip endpoint that implements +// net.Listener. +type Listener struct { + stack *stack.Stack + ep tcpip.Endpoint + wq *waiter.Queue + cancel chan struct{} +} + +// NewListener creates a new Listener. +func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { + // Create TCP endpoint, bind it, then start listening. + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) + if err != nil { + return nil, errors.New(err.String()) + } + + if err := ep.Bind(addr); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "tcp", + Addr: fullToTCPAddr(addr), + Err: errors.New(err.String()), + } + } + + if err := ep.Listen(10); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "listen", + Net: "tcp", + Addr: fullToTCPAddr(addr), + Err: errors.New(err.String()), + } + } + + return &Listener{ + stack: s, + ep: ep, + wq: &wq, + cancel: make(chan struct{}), + }, nil +} + +// Close implements net.Listener.Close. +func (l *Listener) Close() error { + l.ep.Close() + return nil +} + +// Shutdown stops the HTTP server. +func (l *Listener) Shutdown() { + l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + close(l.cancel) // broadcast cancellation +} + +// Addr implements net.Listener.Addr. +func (l *Listener) Addr() net.Addr { + a, err := l.ep.GetLocalAddress() + if err != nil { + return nil + } + return fullToTCPAddr(a) +} + +type deadlineTimer struct { + // mu protects the fields below. + mu sync.Mutex + + readTimer *time.Timer + readCancelCh chan struct{} + writeTimer *time.Timer + writeCancelCh chan struct{} +} + +func (d *deadlineTimer) init() { + d.readCancelCh = make(chan struct{}) + d.writeCancelCh = make(chan struct{}) +} + +func (d *deadlineTimer) readCancel() <-chan struct{} { + d.mu.Lock() + c := d.readCancelCh + d.mu.Unlock() + return c +} +func (d *deadlineTimer) writeCancel() <-chan struct{} { + d.mu.Lock() + c := d.writeCancelCh + d.mu.Unlock() + return c +} + +// setDeadline contains the shared logic for setting a deadline. +// +// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and +// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and +// deadlineTimer.writeTimer. +// +// setDeadline must only be called while holding d.mu. +func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { + if *timer != nil && !(*timer).Stop() { + *cancelCh = make(chan struct{}) + } + + // Create a new channel if we already closed it due to setting an already + // expired time. We won't race with the timer because we already handled + // that above. + select { + case <-*cancelCh: + *cancelCh = make(chan struct{}) + default: + } + + // "A zero value for t means I/O operations will not time out." + // - net.Conn.SetDeadline + if t.IsZero() { + return + } + + timeout := t.Sub(time.Now()) + if timeout <= 0 { + close(*cancelCh) + return + } + + // Timer.Stop returns whether or not the AfterFunc has started, but + // does not indicate whether or not it has completed. Make a copy of + // the cancel channel to prevent this code from racing with the next + // call of setDeadline replacing *cancelCh. + ch := *cancelCh + *timer = time.AfterFunc(timeout, func() { + close(ch) + }) +} + +// SetReadDeadline implements net.Conn.SetReadDeadline and +// net.PacketConn.SetReadDeadline. +func (d *deadlineTimer) SetReadDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.mu.Unlock() + return nil +} + +// SetWriteDeadline implements net.Conn.SetWriteDeadline and +// net.PacketConn.SetWriteDeadline. +func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. +func (d *deadlineTimer) SetDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn +// interface. +type Conn struct { + deadlineTimer + + wq *waiter.Queue + ep tcpip.Endpoint + + // readMu serializes reads and implicitly protects read. + // + // Lock ordering: + // If both readMu and deadlineTimer.mu are to be used in a single + // request, readMu must be acquired before deadlineTimer.mu. + readMu sync.Mutex + + // read contains bytes that have been read from the endpoint, + // but haven't yet been returned. + read buffer.View +} + +// NewConn creates a new Conn. +func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { + c := &Conn{ + wq: wq, + ep: ep, + } + c.deadlineTimer.init() + return c +} + +// Accept implements net.Conn.Accept. +func (l *Listener) Accept() (net.Conn, error) { + n, wq, err := l.ep.Accept() + + if err == tcpip.ErrWouldBlock { + // Create wait queue entry that notifies a channel. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + l.wq.EventRegister(&waitEntry, waiter.EventIn) + defer l.wq.EventUnregister(&waitEntry) + + for { + n, wq, err = l.ep.Accept() + + if err != tcpip.ErrWouldBlock { + break + } + + select { + case <-l.cancel: + return nil, errCanceled + case <-notifyCh: + } + } + } + + if err != nil { + return nil, &net.OpError{ + Op: "accept", + Net: "tcp", + Addr: l.Addr(), + Err: errors.New(err.String()), + } + } + + return NewConn(wq, n), nil +} + +type opErrorer interface { + newOpError(op string, err error) *net.OpError +} + +// commonRead implements the common logic between net.Conn.Read and +// net.PacketConn.ReadFrom. +func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { + select { + case <-deadline: + return nil, errorer.newOpError("read", &timeoutError{}) + default: + } + + read, _, err := ep.Read(addr) + + if err == tcpip.ErrWouldBlock { + if dontWait { + return nil, errWouldBlock + } + // Create wait queue entry that notifies a channel. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventIn) + defer wq.EventUnregister(&waitEntry) + for { + read, _, err = ep.Read(addr) + if err != tcpip.ErrWouldBlock { + break + } + select { + case <-deadline: + return nil, errorer.newOpError("read", &timeoutError{}) + case <-notifyCh: + } + } + } + + if err == tcpip.ErrClosedForReceive { + return nil, io.EOF + } + + if err != nil { + return nil, errorer.newOpError("read", errors.New(err.String())) + } + + return read, nil +} + +// Read implements net.Conn.Read. +func (c *Conn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + deadline := c.readCancel() + + numRead := 0 + for numRead != len(b) { + if len(c.read) == 0 { + var err error + c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) + if err != nil { + if numRead != 0 { + return numRead, nil + } + return numRead, err + } + } + n := copy(b[numRead:], c.read) + c.read.TrimFront(n) + numRead += n + if len(c.read) == 0 { + c.read = nil + } + } + return numRead, nil +} + +// Write implements net.Conn.Write. +func (c *Conn) Write(b []byte) (int, error) { + deadline := c.writeCancel() + + // Check if deadlineTimer has already expired. + select { + case <-deadline: + return 0, c.newOpError("write", &timeoutError{}) + default: + } + + v := buffer.NewViewFromBytes(b) + + // We must handle two soft failure conditions simultaneously: + // 1. Write may write nothing and return tcpip.ErrWouldBlock. + // If this happens, we need to register for notifications if we have + // not already and wait to try again. + // 2. Write may write fewer than the full number of bytes and return + // without error. In this case we need to try writing the remaining + // bytes again. I do not need to register for notifications. + // + // What is more, these two soft failure conditions can be interspersed. + // There is no guarantee that all of the condition #1s will occur before + // all of the condition #2s or visa-versa. + var ( + err *tcpip.Error + nbytes int + reg bool + notifyCh chan struct{} + ) + for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { + if err == tcpip.ErrWouldBlock { + if !reg { + // Only register once. + reg = true + + // Create wait queue entry that notifies a channel. + var waitEntry waiter.Entry + waitEntry, notifyCh = waiter.NewChannelEntry(nil) + c.wq.EventRegister(&waitEntry, waiter.EventOut) + defer c.wq.EventUnregister(&waitEntry) + } else { + // Don't wait immediately after registration in case more data + // became available between when we last checked and when we setup + // the notification. + select { + case <-deadline: + return nbytes, c.newOpError("write", &timeoutError{}) + case <-notifyCh: + } + } + } + + var n int64 + var resCh <-chan struct{} + n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + nbytes += int(n) + v.TrimFront(int(n)) + + if resCh != nil { + select { + case <-deadline: + return nbytes, c.newOpError("write", &timeoutError{}) + case <-resCh: + } + + n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + nbytes += int(n) + v.TrimFront(int(n)) + } + } + + if err == nil { + return nbytes, nil + } + + return nbytes, c.newOpError("write", errors.New(err.String())) +} + +// Close implements net.Conn.Close. +func (c *Conn) Close() error { + c.ep.Close() + return nil +} + +// CloseRead shuts down the reading side of the TCP connection. Most callers +// should just use Close. +// +// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. +func (c *Conn) CloseRead() error { + if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { + return c.newOpError("close", errors.New(terr.String())) + } + return nil +} + +// CloseWrite shuts down the writing side of the TCP connection. Most callers +// should just use Close. +// +// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. +func (c *Conn) CloseWrite() error { + if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { + return c.newOpError("close", errors.New(terr.String())) + } + return nil +} + +// LocalAddr implements net.Conn.LocalAddr. +func (c *Conn) LocalAddr() net.Addr { + a, err := c.ep.GetLocalAddress() + if err != nil { + return nil + } + return fullToTCPAddr(a) +} + +// RemoteAddr implements net.Conn.RemoteAddr. +func (c *Conn) RemoteAddr() net.Addr { + a, err := c.ep.GetRemoteAddress() + if err != nil { + return nil + } + return fullToTCPAddr(a) +} + +func (c *Conn) newOpError(op string, err error) *net.OpError { + return &net.OpError{ + Op: op, + Net: "tcp", + Source: c.LocalAddr(), + Addr: c.RemoteAddr(), + Err: err, + } +} + +func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { + return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} +} + +func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { + return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} +} + +// DialTCP creates a new TCP Conn connected to the specified address. +func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { + return DialContextTCP(context.Background(), s, addr, network) +} + +// DialContextTCP creates a new TCP Conn connected to the specified address +// with the option of adding cancellation and timeouts. +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { + // Create TCP endpoint, then connect. + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) + if err != nil { + return nil, errors.New(err.String()) + } + + // Create wait queue entry that notifies a channel. + // + // We do this unconditionally as Connect will always return an error. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventOut) + defer wq.EventUnregister(&waitEntry) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + err = ep.Connect(addr) + if err == tcpip.ErrConnectStarted { + select { + case <-ctx.Done(): + ep.Close() + return nil, ctx.Err() + case <-notifyCh: + } + + err = ep.GetSockOpt(tcpip.ErrorOption{}) + } + if err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "tcp", + Addr: fullToTCPAddr(addr), + Err: errors.New(err.String()), + } + } + + return NewConn(&wq, ep), nil +} + +// A PacketConn is a wrapper around a tcpip endpoint that implements +// net.PacketConn. +type PacketConn struct { + deadlineTimer + + stack *stack.Stack + ep tcpip.Endpoint + wq *waiter.Queue +} + +// DialUDP creates a new PacketConn. +// +// If laddr is nil, a local address is automatically chosen. +// +// If raddr is nil, the PacketConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) + if err != nil { + return nil, errors.New(err.String()) + } + + if laddr != nil { + if err := ep.Bind(*laddr); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "udp", + Addr: fullToUDPAddr(*laddr), + Err: errors.New(err.String()), + } + } + } + + c := PacketConn{ + stack: s, + ep: ep, + wq: &wq, + } + c.deadlineTimer.init() + + if raddr != nil { + if err := c.ep.Connect(*raddr); err != nil { + c.ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "udp", + Addr: fullToUDPAddr(*raddr), + Err: errors.New(err.String()), + } + } + } + + return &c, nil +} + +func (c *PacketConn) newOpError(op string, err error) *net.OpError { + return c.newRemoteOpError(op, nil, err) +} + +func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { + return &net.OpError{ + Op: op, + Net: "udp", + Source: c.LocalAddr(), + Addr: remote, + Err: err, + } +} + +// RemoteAddr implements net.Conn.RemoteAddr. +func (c *PacketConn) RemoteAddr() net.Addr { + a, err := c.ep.GetRemoteAddress() + if err != nil { + return nil + } + return fullToTCPAddr(a) +} + +// Read implements net.Conn.Read +func (c *PacketConn) Read(b []byte) (int, error) { + bytesRead, _, err := c.ReadFrom(b) + return bytesRead, err +} + +// ReadFrom implements net.PacketConn.ReadFrom. +func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + deadline := c.readCancel() + + var addr tcpip.FullAddress + read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) + if err != nil { + return 0, nil, err + } + + return copy(b, read), fullToUDPAddr(addr), nil +} + +func (c *PacketConn) Write(b []byte) (int, error) { + return c.WriteTo(b, nil) +} + +// WriteTo implements net.PacketConn.WriteTo. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + deadline := c.writeCancel() + + // Check if deadline has already expired. + select { + case <-deadline: + return 0, c.newRemoteOpError("write", addr, &timeoutError{}) + default: + } + + // If we're being called by Write, there is no addr + wopts := tcpip.WriteOptions{} + if addr != nil { + ua := addr.(*net.UDPAddr) + wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + } + + v := buffer.NewView(len(b)) + copy(v, b) + + n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts) + if resCh != nil { + select { + case <-deadline: + return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) + case <-resCh: + } + + n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + } + + if err == tcpip.ErrWouldBlock { + // Create wait queue entry that notifies a channel. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.wq.EventRegister(&waitEntry, waiter.EventOut) + defer c.wq.EventUnregister(&waitEntry) + for { + select { + case <-deadline: + return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) + case <-notifyCh: + } + + n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + if err != tcpip.ErrWouldBlock { + break + } + } + } + + if err == nil { + return int(n), nil + } + + return int(n), c.newRemoteOpError("write", addr, errors.New(err.String())) +} + +// Close implements net.PacketConn.Close. +func (c *PacketConn) Close() error { + c.ep.Close() + return nil +} + +// LocalAddr implements net.PacketConn.LocalAddr. +func (c *PacketConn) LocalAddr() net.Addr { + a, err := c.ep.GetLocalAddress() + if err != nil { + return nil + } + return fullToUDPAddr(a) +} diff --git a/pkg/tcpip/adapters/gonet/gonet_state_autogen.go b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go new file mode 100755 index 000000000..9b87956fd --- /dev/null +++ b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package gonet + diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go new file mode 100755 index 000000000..445b22c17 --- /dev/null +++ b/pkg/tcpip/link/muxed/injectable.go @@ -0,0 +1,137 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package muxed provides a muxed link endpoints. +package muxed + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// InjectableEndpoint is an injectable multi endpoint. The endpoint has +// trivial routing rules that determine which InjectableEndpoint a given packet +// will be written to. Note that HandleLocal works differently for this +// endpoint (see WritePacket). +type InjectableEndpoint struct { + routes map[tcpip.Address]stack.InjectableLinkEndpoint + dispatcher stack.NetworkDispatcher +} + +// MTU implements stack.LinkEndpoint. +func (m *InjectableEndpoint) MTU() uint32 { + minMTU := ^uint32(0) + for _, endpoint := range m.routes { + if endpointMTU := endpoint.MTU(); endpointMTU < minMTU { + minMTU = endpointMTU + } + } + return minMTU +} + +// Capabilities implements stack.LinkEndpoint. +func (m *InjectableEndpoint) Capabilities() stack.LinkEndpointCapabilities { + minCapabilities := stack.LinkEndpointCapabilities(^uint(0)) + for _, endpoint := range m.routes { + minCapabilities &= endpoint.Capabilities() + } + return minCapabilities +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (m *InjectableEndpoint) MaxHeaderLength() uint16 { + minHeaderLen := ^uint16(0) + for _, endpoint := range m.routes { + if headerLen := endpoint.MaxHeaderLength(); headerLen < minHeaderLen { + minHeaderLen = headerLen + } + } + return minHeaderLen +} + +// LinkAddress implements stack.LinkEndpoint. +func (m *InjectableEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Attach implements stack.LinkEndpoint. +func (m *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + for _, endpoint := range m.routes { + endpoint.Attach(dispatcher) + } + m.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint. +func (m *InjectableEndpoint) IsAttached() bool { + return m.dispatcher != nil +} + +// InjectInbound implements stack.InjectableLinkEndpoint. +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) +} + +// WritePackets writes outbound packets to the appropriate +// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if +// r.RemoteAddress has a route registered in this endpoint. +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + endpoint, ok := m.routes[r.RemoteAddress] + if !ok { + return 0, tcpip.ErrNoRoute + } + return endpoint.WritePackets(r, gso, pkts, protocol) +} + +// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint +// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a +// route registered in this endpoint. +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + if endpoint, ok := m.routes[r.RemoteAddress]; ok { + return endpoint.WritePacket(r, gso, protocol, pkt) + } + return tcpip.ErrNoRoute +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // WriteRawPacket doesn't get a route or network address, so there's + // nowhere to write this. + return tcpip.ErrNoRoute +} + +// InjectOutbound writes outbound packets to the appropriate +// LinkInjectableEndpoint based on the dest address. +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { + endpoint, ok := m.routes[dest] + if !ok { + return tcpip.ErrNoRoute + } + return endpoint.InjectOutbound(dest, packet) +} + +// Wait implements stack.LinkEndpoint.Wait. +func (m *InjectableEndpoint) Wait() { + for _, ep := range m.routes { + ep.Wait() + } +} + +// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. +func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { + return &InjectableEndpoint{ + routes: routes, + } +} diff --git a/pkg/tcpip/link/muxed/muxed_state_autogen.go b/pkg/tcpip/link/muxed/muxed_state_autogen.go new file mode 100755 index 000000000..e3330c0da --- /dev/null +++ b/pkg/tcpip/link/muxed/muxed_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package muxed + diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go new file mode 100755 index 000000000..74c9f0311 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe.go @@ -0,0 +1,78 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pipe implements a shared memory ring buffer on which a single reader +// and a single writer can operate (read/write) concurrently. The ring buffer +// allows for data of different sizes to be written, and preserves the boundary +// of the written data. +// +// Example usage is as follows: +// +// wb := t.Push(20) +// // Write data to wb. +// t.Flush() +// +// rb := r.Pull() +// // Do something with data in rb. +// t.Flush() +package pipe + +import ( + "math" +) + +const ( + jump uint64 = math.MaxUint32 + 1 + offsetMask uint64 = math.MaxUint32 + revolutionMask uint64 = ^offsetMask + + sizeOfSlotHeader = 8 // sizeof(uint64) + slotFree uint64 = 1 << 63 + slotSizeMask uint64 = math.MaxUint32 +) + +// payloadToSlotSize calculates the total size of a slot based on its payload +// size. The total size is the header size, plus the payload size, plus padding +// if necessary to make the total size a multiple of sizeOfSlotHeader. +func payloadToSlotSize(payloadSize uint64) uint64 { + s := sizeOfSlotHeader + payloadSize + return (s + sizeOfSlotHeader - 1) &^ (sizeOfSlotHeader - 1) +} + +// slotToPayloadSize calculates the payload size of a slot based on the total +// size of the slot. This is only meant to be used when creating slots that +// don't carry information (e.g., free slots or wrap slots). +func slotToPayloadSize(offset uint64) uint64 { + return offset - sizeOfSlotHeader +} + +// pipe is a basic data structure used by both (transmit & receive) ends of a +// pipe. Indices into this pipe are split into two fields: offset, which counts +// the number of bytes from the beginning of the buffer, and revolution, which +// counts the number of times the index has wrapped around. +type pipe struct { + buffer []byte +} + +// init initializes the pipe buffer such that its size is a multiple of the size +// of the slot header. +func (p *pipe) init(b []byte) { + p.buffer = b[:len(b)&^(sizeOfSlotHeader-1)] +} + +// data returns a section of the buffer starting at the given index (which may +// include revolution information) and with the given size. +func (p *pipe) data(idx uint64, size uint64) []byte { + return p.buffer[(idx&offsetMask)+sizeOfSlotHeader:][:size] +} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go new file mode 100755 index 000000000..c7c7c21b3 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package pipe + diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go new file mode 100755 index 000000000..62d17029e --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "sync/atomic" + "unsafe" +) + +func (p *pipe) write(idx uint64, v uint64) { + ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) + *ptr = v +} + +func (p *pipe) writeAtomic(idx uint64, v uint64) { + ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) + atomic.StoreUint64(ptr, v) +} + +func (p *pipe) readAtomic(idx uint64) uint64 { + ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) + return atomic.LoadUint64(ptr) +} diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go new file mode 100755 index 000000000..f22e533ac --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/rx.go @@ -0,0 +1,93 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +// Rx is the receive side of the shared memory ring buffer. +type Rx struct { + p pipe + + tail uint64 + head uint64 +} + +// Init initializes the receive end of the pipe. In the initial state, the next +// slot to be inspected is the very first one. +func (r *Rx) Init(b []byte) { + r.p.init(b) + r.tail = 0xfffffffe * jump + r.head = r.tail +} + +// Pull reads the next buffer from the pipe, returning nil if there isn't one +// currently available. +// +// The returned slice is available until Flush() is next called. After that, it +// must not be touched. +func (r *Rx) Pull() []byte { + if r.head == r.tail+jump { + // We've already pulled the whole pipe. + return nil + } + + header := r.p.readAtomic(r.head) + if header&slotFree != 0 { + // The next slot is free, we can't pull it yet. + return nil + } + + payloadSize := header & slotSizeMask + newHead := r.head + payloadToSlotSize(payloadSize) + headWrap := (r.head & revolutionMask) | uint64(len(r.p.buffer)) + + // Check if this is a wrapping slot. If that's the case, it carries no + // data, so we just skip it and try again from the first slot. + if int64(newHead-headWrap) >= 0 { + if int64(newHead-headWrap) > int64(jump) || newHead&offsetMask != 0 { + return nil + } + + if r.tail == r.head { + // If this is the first pull since the last Flush() + // call, we flush the state so that the sender can use + // this space if it needs to. + r.p.writeAtomic(r.head, slotFree|slotToPayloadSize(newHead-r.head)) + r.tail = newHead + } + + r.head = newHead + return r.Pull() + } + + // Grab the buffer before updating r.head. + b := r.p.data(r.head, payloadSize) + r.head = newHead + return b +} + +// Flush tells the transmitter that all buffers pulled since the last Flush() +// have been used, so the transmitter is free to used their slots for further +// transmission. +func (r *Rx) Flush() { + if r.head == r.tail { + return + } + r.p.writeAtomic(r.tail, slotFree|slotToPayloadSize(r.head-r.tail)) + r.tail = r.head +} + +// Bytes returns the byte slice on which the pipe operates. +func (r *Rx) Bytes() []byte { + return r.p.buffer +} diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go new file mode 100755 index 000000000..9841eb231 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/tx.go @@ -0,0 +1,161 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +// Tx is the transmit side of the shared memory ring buffer. +type Tx struct { + p pipe + maxPayloadSize uint64 + + head uint64 + tail uint64 + next uint64 + + tailHeader uint64 +} + +// Init initializes the transmit end of the pipe. In the initial state, the next +// slot to be written is the very first one, and the transmitter has the whole +// ring buffer available to it. +func (t *Tx) Init(b []byte) { + t.p.init(b) + // maxPayloadSize excludes the header of the payload, and the header + // of the wrapping message. + t.maxPayloadSize = uint64(len(t.p.buffer)) - 2*sizeOfSlotHeader + t.tail = 0xfffffffe * jump + t.next = t.tail + t.head = t.tail + jump + t.p.write(t.tail, slotFree) +} + +// Capacity determines how many records of the given size can be written to the +// pipe before it fills up. +func (t *Tx) Capacity(recordSize uint64) uint64 { + available := uint64(len(t.p.buffer)) - sizeOfSlotHeader + entryLen := payloadToSlotSize(recordSize) + return available / entryLen +} + +// Push reserves "payloadSize" bytes for transmission in the pipe. The caller +// populates the returned slice with the data to be transferred and enventually +// calls Flush() to make the data visible to the reader, or Abort() to make the +// pipe forget all Push() calls since the last Flush(). +// +// The returned slice is available until Flush() or Abort() is next called. +// After that, it must not be touched. +func (t *Tx) Push(payloadSize uint64) []byte { + // Fail request if we know we will never have enough room. + if payloadSize > t.maxPayloadSize { + return nil + } + + totalLen := payloadToSlotSize(payloadSize) + newNext := t.next + totalLen + nextWrap := (t.next & revolutionMask) | uint64(len(t.p.buffer)) + if int64(newNext-nextWrap) >= 0 { + // The new buffer would overflow the pipe, so we push a wrapping + // slot, then try to add the actual slot to the front of the + // pipe. + newNext = (newNext & revolutionMask) + jump + wrappingPayloadSize := slotToPayloadSize(newNext - t.next) + if !t.reclaim(newNext) { + return nil + } + + oldNext := t.next + t.next = newNext + if oldNext != t.tail { + t.p.write(oldNext, wrappingPayloadSize) + } else { + t.tailHeader = wrappingPayloadSize + t.Flush() + } + + newNext += totalLen + } + + // Check that we have enough room for the buffer. + if !t.reclaim(newNext) { + return nil + } + + if t.next != t.tail { + t.p.write(t.next, payloadSize) + } else { + t.tailHeader = payloadSize + } + + // Grab the buffer before updating t.next. + b := t.p.data(t.next, payloadSize) + t.next = newNext + + return b +} + +// reclaim attempts to advance the head until at least newNext. If the head is +// already at or beyond newNext, nothing happens and true is returned; otherwise +// it tries to reclaim slots that have already been consumed by the receive end +// of the pipe (they will be marked as free) and returns a boolean indicating +// whether it was successful in reclaiming enough slots. +func (t *Tx) reclaim(newNext uint64) bool { + for int64(newNext-t.head) > 0 { + // Can't reclaim if slot is not free. + header := t.p.readAtomic(t.head) + if header&slotFree == 0 { + return false + } + + payloadSize := header & slotSizeMask + newHead := t.head + payloadToSlotSize(payloadSize) + + // Check newHead is within bounds and valid. + if int64(newHead-t.tail) > int64(jump) || newHead&offsetMask >= uint64(len(t.p.buffer)) { + return false + } + + t.head = newHead + } + + return true +} + +// Abort causes all Push() calls since the last Flush() to be forgotten and +// therefore they will not be made visible to the receiver. +func (t *Tx) Abort() { + t.next = t.tail +} + +// Flush causes all buffers pushed since the last Flush() [or Abort(), whichever +// is the most recent] to be made visible to the receiver. +func (t *Tx) Flush() { + if t.next == t.tail { + // Nothing to do if there are no pushed buffers. + return + } + + if t.next != t.head { + // The receiver will spin in t.next, so we must make sure that + // the slotFree bit is set. + t.p.write(t.next, slotFree) + } + + t.p.writeAtomic(t.tail, t.tailHeader) + t.tail = t.next +} + +// Bytes returns the byte slice on which the pipe operates. +func (t *Tx) Bytes() []byte { + return t.p.buffer +} diff --git a/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go new file mode 100755 index 000000000..eec17d734 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package queue + diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go new file mode 100755 index 000000000..696e6c9e5 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/rx.go @@ -0,0 +1,221 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package queue provides the implementation of transmit and receive queues +// based on shared memory ring buffers. +package queue + +import ( + "encoding/binary" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" +) + +const ( + // Offsets within a posted buffer. + postedOffset = 0 + postedSize = 8 + postedRemainingInGroup = 12 + postedUserData = 16 + postedID = 24 + + sizeOfPostedBuffer = 32 + + // Offsets within a received packet header. + consumedPacketSize = 0 + consumedPacketReserved = 4 + + sizeOfConsumedPacketHeader = 8 + + // Offsets within a consumed buffer. + consumedOffset = 0 + consumedSize = 8 + consumedUserData = 12 + consumedID = 20 + + sizeOfConsumedBuffer = 28 + + // The following are the allowed states of the shared data area. + eventFDUninitialized = 0 + eventFDDisabled = 1 + eventFDEnabled = 2 +) + +// RxBuffer is the descriptor of a receive buffer. +type RxBuffer struct { + Offset uint64 + Size uint32 + ID uint64 + UserData uint64 +} + +// Rx is a receive queue. It is implemented with one tx and one rx pipe: the tx +// pipe is used to "post" buffers, while the rx pipe is used to receive packets +// whose contents have been written to previously posted buffers. +// +// This struct is thread-compatible. +type Rx struct { + tx pipe.Tx + rx pipe.Rx + sharedEventFDState *uint32 +} + +// Init initializes the receive queue with the given pipes, and shared state +// pointer -- the latter is used to enable/disable eventfd notifications. +func (r *Rx) Init(tx, rx []byte, sharedEventFDState *uint32) { + r.sharedEventFDState = sharedEventFDState + r.tx.Init(tx) + r.rx.Init(rx) +} + +// EnableNotification updates the shared state such that the peer will notify +// the eventfd when there are packets to be dequeued. +func (r *Rx) EnableNotification() { + atomic.StoreUint32(r.sharedEventFDState, eventFDEnabled) +} + +// DisableNotification updates the shared state such that the peer will not +// notify the eventfd. +func (r *Rx) DisableNotification() { + atomic.StoreUint32(r.sharedEventFDState, eventFDDisabled) +} + +// PostedBuffersLimit returns the maximum number of buffers that can be posted +// before the tx queue fills up. +func (r *Rx) PostedBuffersLimit() uint64 { + return r.tx.Capacity(sizeOfPostedBuffer) +} + +// PostBuffers makes the given buffers available for receiving data from the +// peer. Once they are posted, the peer is free to write to them and will +// eventually post them back for consumption. +func (r *Rx) PostBuffers(buffers []RxBuffer) bool { + for i := range buffers { + b := r.tx.Push(sizeOfPostedBuffer) + if b == nil { + r.tx.Abort() + return false + } + + pb := &buffers[i] + binary.LittleEndian.PutUint64(b[postedOffset:], pb.Offset) + binary.LittleEndian.PutUint32(b[postedSize:], pb.Size) + binary.LittleEndian.PutUint32(b[postedRemainingInGroup:], 0) + binary.LittleEndian.PutUint64(b[postedUserData:], pb.UserData) + binary.LittleEndian.PutUint64(b[postedID:], pb.ID) + } + + r.tx.Flush() + + return true +} + +// Dequeue receives buffers that have been previously posted by PostBuffers() +// and that have been filled by the peer and posted back. +// +// This is similar to append() in that new buffers are appended to "bufs", with +// reallocation only if "bufs" doesn't have enough capacity. +func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) { + for { + outBufs := bufs + + // Pull the next descriptor from the rx pipe. + b := r.rx.Pull() + if b == nil { + return bufs, 0 + } + + if len(b) < sizeOfConsumedPacketHeader { + log.Warningf("Ignoring packet header: size (%v) is less than header size (%v)", len(b), sizeOfConsumedPacketHeader) + r.rx.Flush() + continue + } + + totalDataSize := binary.LittleEndian.Uint32(b[consumedPacketSize:]) + + // Calculate the number of buffer descriptors and copy them + // over to the output. + count := (len(b) - sizeOfConsumedPacketHeader) / sizeOfConsumedBuffer + offset := sizeOfConsumedPacketHeader + buffersSize := uint32(0) + for i := count; i > 0; i-- { + s := binary.LittleEndian.Uint32(b[offset+consumedSize:]) + buffersSize += s + if buffersSize < s { + // The buffer size overflows an unsigned 32-bit + // integer, so break out and force it to be + // ignored. + totalDataSize = 1 + buffersSize = 0 + break + } + + outBufs = append(outBufs, RxBuffer{ + Offset: binary.LittleEndian.Uint64(b[offset+consumedOffset:]), + Size: s, + ID: binary.LittleEndian.Uint64(b[offset+consumedID:]), + }) + + offset += sizeOfConsumedBuffer + } + + r.rx.Flush() + + if buffersSize < totalDataSize { + // The descriptor is corrupted, ignore it. + log.Warningf("Ignoring packet: actual data size (%v) less than expected size (%v)", buffersSize, totalDataSize) + continue + } + + return outBufs, totalDataSize + } +} + +// Bytes returns the byte slices on which the queue operates. +func (r *Rx) Bytes() (tx, rx []byte) { + return r.tx.Bytes(), r.rx.Bytes() +} + +// DecodeRxBufferHeader decodes the header of a buffer posted on an rx queue. +func DecodeRxBufferHeader(b []byte) RxBuffer { + return RxBuffer{ + Offset: binary.LittleEndian.Uint64(b[postedOffset:]), + Size: binary.LittleEndian.Uint32(b[postedSize:]), + ID: binary.LittleEndian.Uint64(b[postedID:]), + UserData: binary.LittleEndian.Uint64(b[postedUserData:]), + } +} + +// RxCompletionSize returns the number of bytes needed to encode an rx +// completion containing "count" buffers. +func RxCompletionSize(count int) uint64 { + return sizeOfConsumedPacketHeader + uint64(count)*sizeOfConsumedBuffer +} + +// EncodeRxCompletion encodes an rx completion header. +func EncodeRxCompletion(b []byte, size, reserved uint32) { + binary.LittleEndian.PutUint32(b[consumedPacketSize:], size) + binary.LittleEndian.PutUint32(b[consumedPacketReserved:], reserved) +} + +// EncodeRxCompletionBuffer encodes the i-th rx completion buffer header. +func EncodeRxCompletionBuffer(b []byte, i int, rxb RxBuffer) { + b = b[RxCompletionSize(i):] + binary.LittleEndian.PutUint64(b[consumedOffset:], rxb.Offset) + binary.LittleEndian.PutUint32(b[consumedSize:], rxb.Size) + binary.LittleEndian.PutUint64(b[consumedUserData:], rxb.UserData) + binary.LittleEndian.PutUint64(b[consumedID:], rxb.ID) +} diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go new file mode 100755 index 000000000..beffe807b --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/tx.go @@ -0,0 +1,151 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" +) + +const ( + // Offsets within a packet header. + packetID = 0 + packetSize = 8 + packetReserved = 12 + + sizeOfPacketHeader = 16 + + // Offsets with a buffer descriptor + bufferOffset = 0 + bufferSize = 8 + + sizeOfBufferDescriptor = 12 +) + +// TxBuffer is the descriptor of a transmit buffer. +type TxBuffer struct { + Next *TxBuffer + Offset uint64 + Size uint32 +} + +// Tx is a transmit queue. It is implemented with one tx and one rx pipe: the +// tx pipe is used to request the transmission of packets, while the rx pipe +// is used to receive which transmissions have completed. +// +// This struct is thread-compatible. +type Tx struct { + tx pipe.Tx + rx pipe.Rx +} + +// Init initializes the transmit queue with the given pipes. +func (t *Tx) Init(tx, rx []byte) { + t.tx.Init(tx) + t.rx.Init(rx) +} + +// Enqueue queues the given linked list of buffers for transmission as one +// packet. While it is queued, the caller must not modify them. +func (t *Tx) Enqueue(id uint64, totalDataLen, bufferCount uint32, buffer *TxBuffer) bool { + // Reserve room in the tx pipe. + totalLen := sizeOfPacketHeader + uint64(bufferCount)*sizeOfBufferDescriptor + + b := t.tx.Push(totalLen) + if b == nil { + return false + } + + // Initialize the packet and buffer descriptors. + binary.LittleEndian.PutUint64(b[packetID:], id) + binary.LittleEndian.PutUint32(b[packetSize:], totalDataLen) + binary.LittleEndian.PutUint32(b[packetReserved:], 0) + + offset := sizeOfPacketHeader + for i := bufferCount; i != 0; i-- { + binary.LittleEndian.PutUint64(b[offset+bufferOffset:], buffer.Offset) + binary.LittleEndian.PutUint32(b[offset+bufferSize:], buffer.Size) + offset += sizeOfBufferDescriptor + buffer = buffer.Next + } + + t.tx.Flush() + + return true +} + +// CompletedPacket returns the id of the last completed transmission. The +// returned id, if any, refers to a value passed on a previous call to +// Enqueue(). +func (t *Tx) CompletedPacket() (id uint64, ok bool) { + for { + b := t.rx.Pull() + if b == nil { + return 0, false + } + + if len(b) != 8 { + t.rx.Flush() + log.Warningf("Ignoring completed packet: size (%v) is less than expected (%v)", len(b), 8) + continue + } + + v := binary.LittleEndian.Uint64(b) + + t.rx.Flush() + + return v, true + } +} + +// Bytes returns the byte slices on which the queue operates. +func (t *Tx) Bytes() (tx, rx []byte) { + return t.tx.Bytes(), t.rx.Bytes() +} + +// TxPacketInfo holds information about a packet sent on a tx queue. +type TxPacketInfo struct { + ID uint64 + Size uint32 + Reserved uint32 + BufferCount int +} + +// DecodeTxPacketHeader decodes the header of a packet sent over a tx queue. +func DecodeTxPacketHeader(b []byte) TxPacketInfo { + return TxPacketInfo{ + ID: binary.LittleEndian.Uint64(b[packetID:]), + Size: binary.LittleEndian.Uint32(b[packetSize:]), + Reserved: binary.LittleEndian.Uint32(b[packetReserved:]), + BufferCount: (len(b) - sizeOfPacketHeader) / sizeOfBufferDescriptor, + } +} + +// DecodeTxBufferHeader decodes the header of the i-th buffer of a packet sent +// over a tx queue. +func DecodeTxBufferHeader(b []byte, i int) TxBuffer { + b = b[sizeOfPacketHeader+i*sizeOfBufferDescriptor:] + return TxBuffer{ + Offset: binary.LittleEndian.Uint64(b[bufferOffset:]), + Size: binary.LittleEndian.Uint32(b[bufferSize:]), + } +} + +// EncodeTxCompletion encodes a tx completion header. +func EncodeTxCompletion(b []byte, id uint64) { + binary.LittleEndian.PutUint64(b, id) +} diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go new file mode 100755 index 000000000..eec11e4cb --- /dev/null +++ b/pkg/tcpip/link/sharedmem/rx.go @@ -0,0 +1,159 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package sharedmem + +import ( + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +// rx holds all state associated with an rx queue. +type rx struct { + data []byte + sharedData []byte + q queue.Rx + eventFD int +} + +// init initializes all state needed by the rx queue based on the information +// provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (r *rx) init(mtu uint32, c *QueueConfig) error { + // Map in all buffers. + txPipe, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + + rxPipe, err := getBuffer(c.RxPipeFD) + if err != nil { + syscall.Munmap(txPipe) + return err + } + + data, err := getBuffer(c.DataFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + return err + } + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + syscall.Munmap(data) + return err + } + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := syscall.Dup(c.EventFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + syscall.Munmap(data) + syscall.Munmap(sharedData) + return err + } + + // Set the eventfd as non-blocking. + if err := syscall.SetNonblock(efd, true); err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + syscall.Munmap(data) + syscall.Munmap(sharedData) + syscall.Close(efd) + return err + } + + // Initialize state based on buffers. + r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData)) + r.data = data + r.eventFD = efd + r.sharedData = sharedData + + return nil +} + +// cleanup releases all resources allocated during init(). It must only be +// called if init() has previously succeeded. +func (r *rx) cleanup() { + a, b := r.q.Bytes() + syscall.Munmap(a) + syscall.Munmap(b) + + syscall.Munmap(r.data) + syscall.Munmap(r.sharedData) + syscall.Close(r.eventFD) +} + +// postAndReceive posts the provided buffers (if any), and then tries to read +// from the receive queue. +// +// Capacity permitting, it reuses the posted buffer slice to store the buffers +// that were read as well. +// +// This function will block if there aren't any available packets. +func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.RxBuffer, uint32) { + // Post the buffers first. If we cannot post, sleep until we can. We + // never post more than will fit concurrently, so it's safe to wait + // until enough room is available. + if len(b) != 0 && !r.q.PostBuffers(b) { + r.q.EnableNotification() + for !r.q.PostBuffers(b) { + var tmp [8]byte + rawfile.BlockingRead(r.eventFD, tmp[:]) + if atomic.LoadUint32(stopRequested) != 0 { + r.q.DisableNotification() + return nil, 0 + } + } + r.q.DisableNotification() + } + + // Read the next set of descriptors. + b, n := r.q.Dequeue(b[:0]) + if len(b) != 0 { + return b, n + } + + // Data isn't immediately available. Enable eventfd notifications. + r.q.EnableNotification() + for { + b, n = r.q.Dequeue(b) + if len(b) != 0 { + break + } + + // Wait for notification. + var tmp [8]byte + rawfile.BlockingRead(r.eventFD, tmp[:]) + if atomic.LoadUint32(stopRequested) != 0 { + r.q.DisableNotification() + return nil, 0 + } + } + r.q.DisableNotification() + + return b, n +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go new file mode 100755 index 000000000..080f9d667 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -0,0 +1,289 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package sharedmem provides the implemention of data-link layer endpoints +// backed by shared memory. +// +// Shared memory endpoints can be used in the networking stack by calling New() +// to create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +package sharedmem + +import ( + "sync" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// QueueConfig holds all the file descriptors needed to describe a tx or rx +// queue over shared memory. It is used when creating new shared memory +// endpoints to describe tx and rx queues. +type QueueConfig struct { + // DataFD is a file descriptor for the file that contains the data to + // be transmitted via this queue. Descriptors contain offsets within + // this file. + DataFD int + + // EventFD is a file descriptor for the event that is signaled when + // data is becomes available in this queue. + EventFD int + + // TxPipeFD is a file descriptor for the tx pipe associated with the + // queue. + TxPipeFD int + + // RxPipeFD is a file descriptor for the rx pipe associated with the + // queue. + RxPipeFD int + + // SharedDataFD is a file descriptor for the file that contains shared + // state between the two ends of the queue. This data specifies, for + // example, whether EventFD signaling is enabled or disabled. + SharedDataFD int +} + +type endpoint struct { + // mtu (maximum transmission unit) is the maximum size of a packet. + mtu uint32 + + // bufferSize is the size of each individual buffer. + bufferSize uint32 + + // addr is the local address of this endpoint. + addr tcpip.LinkAddress + + // rx is the receive queue. + rx rx + + // stopRequested is to be accessed atomically only, and determines if + // the worker goroutines should stop. + stopRequested uint32 + + // Wait group used to indicate that all workers have stopped. + completed sync.WaitGroup + + // mu protects the following fields. + mu sync.Mutex + + // tx is the transmit queue. + tx tx + + // workerStarted specifies whether the worker goroutine was started. + workerStarted bool +} + +// New creates a new shared-memory-based endpoint. Buffers will be broken up +// into buffers of "bufferSize" bytes. +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { + e := &endpoint{ + mtu: mtu, + bufferSize: bufferSize, + addr: addr, + } + + if err := e.tx.init(bufferSize, &tx); err != nil { + return nil, err + } + + if err := e.rx.init(bufferSize, &rx); err != nil { + e.tx.cleanup() + return nil, err + } + + return e, nil +} + +// Close frees all resources associated with the endpoint. +func (e *endpoint) Close() { + // Tell dispatch goroutine to stop, then write to the eventfd so that + // it wakes up in case it's sleeping. + atomic.StoreUint32(&e.stopRequested, 1) + syscall.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Cleanup the queues inline if the worker hasn't started yet; we also + // know it won't start from now on because stopRequested is set to 1. + e.mu.Lock() + workerPresent := e.workerStarted + e.mu.Unlock() + + if !workerPresent { + e.tx.cleanup() + e.rx.cleanup() + } +} + +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. +func (e *endpoint) Wait() { + e.completed.Wait() +} + +// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that +// reads packets from the rx queue. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { + e.workerStarted = true + e.completed.Add(1) + // Link endpoints are not savable. When transportation endpoints + // are saved, they stop sending outgoing packets and all + // incoming packets are rejected. + go e.dispatchLoop(dispatcher) // S/R-SAFE: see above. + } + e.mu.Unlock() +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *endpoint) IsAttached() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.workerStarted +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *endpoint) MTU() uint32 { + return e.mtu - header.EthernetMinimumSize +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the +// ethernet frame header size. +func (*endpoint) MaxHeaderLength() uint16 { + return header.EthernetMinimumSize +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local +// link address. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + // Add the ethernet header here. + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + + v := pkt.Data.ToView() + // Transmit the packet. + e.mu.Lock() + ok := e.tx.transmit(pkt.Header.View(), v) + e.mu.Unlock() + + if !ok { + return tcpip.ErrWouldBlock + } + + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + v := vv.ToView() + // Transmit the packet. + e.mu.Lock() + ok := e.tx.transmit(v, buffer.View{}) + e.mu.Unlock() + + if !ok { + return tcpip.ErrWouldBlock + } + + return nil +} + +// dispatchLoop reads packets from the rx queue in a loop and dispatches them +// to the network stack. +func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { + // Post initial set of buffers. + limit := e.rx.q.PostedBuffersLimit() + if l := uint64(len(e.rx.data)) / uint64(e.bufferSize); limit > l { + limit = l + } + for i := uint64(0); i < limit; i++ { + b := queue.RxBuffer{ + Offset: i * uint64(e.bufferSize), + Size: e.bufferSize, + ID: i, + } + if !e.rx.q.PostBuffers([]queue.RxBuffer{b}) { + log.Warningf("Unable to post %v-th buffer", i) + } + } + + // Read in a loop until a stop is requested. + var rxb []queue.RxBuffer + for atomic.LoadUint32(&e.stopRequested) == 0 { + var n uint32 + rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested) + + // Copy data from the shared area to its own buffer, then + // prepare to repost the buffer. + b := make([]byte, n) + offset := uint32(0) + for i := range rxb { + copy(b[offset:], e.rx.data[rxb[i].Offset:][:rxb[i].Size]) + offset += rxb[i].Size + + rxb[i].Size = e.bufferSize + } + + if n < header.EthernetMinimumSize { + continue + } + + // Send packet up the stack. + eth := header.Ethernet(b[:header.EthernetMinimumSize]) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) + } + + // Clean state. + e.tx.cleanup() + e.rx.cleanup() + + e.completed.Done() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go new file mode 100755 index 000000000..e5c542528 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package sharedmem + diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go new file mode 100755 index 000000000..f7e816a41 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go @@ -0,0 +1,25 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sharedmem + +import ( + "unsafe" +) + +// sharedDataPointer converts the shared data slice into a pointer so that it +// can be used in atomic operations. +func sharedDataPointer(sharedData []byte) *uint32 { + return (*uint32)(unsafe.Pointer(&sharedData[0:4][0])) +} diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go new file mode 100755 index 000000000..6b8d7859d --- /dev/null +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -0,0 +1,272 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sharedmem + +import ( + "math" + "syscall" + + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +const ( + nilID = math.MaxUint64 +) + +// tx holds all state associated with a tx queue. +type tx struct { + data []byte + q queue.Tx + ids idManager + bufs bufferManager +} + +// init initializes all state needed by the tx queue based on the information +// provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (t *tx) init(mtu uint32, c *QueueConfig) error { + // Map in all buffers. + txPipe, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + + rxPipe, err := getBuffer(c.RxPipeFD) + if err != nil { + syscall.Munmap(txPipe) + return err + } + + data, err := getBuffer(c.DataFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + return err + } + + // Initialize state based on buffers. + t.q.Init(txPipe, rxPipe) + t.ids.init() + t.bufs.init(0, len(data), int(mtu)) + t.data = data + + return nil +} + +// cleanup releases all resources allocated during init(). It must only be +// called if init() has previously succeeded. +func (t *tx) cleanup() { + a, b := t.q.Bytes() + syscall.Munmap(a) + syscall.Munmap(b) + syscall.Munmap(t.data) +} + +// transmit sends a packet made up of up to two buffers. Returns a boolean that +// specifies whether the packet was successfully transmitted. +func (t *tx) transmit(a, b []byte) bool { + // Pull completions from the tx queue and add their buffers back to the + // pool so that we can reuse them. + for { + id, ok := t.q.CompletedPacket() + if !ok { + break + } + + if buf := t.ids.remove(id); buf != nil { + t.bufs.free(buf) + } + } + + bSize := t.bufs.entrySize + total := uint32(len(a) + len(b)) + bufCount := (total + bSize - 1) / bSize + + // Allocate enough buffers to hold all the data. + var buf *queue.TxBuffer + for i := bufCount; i != 0; i-- { + b := t.bufs.alloc() + if b == nil { + // Failed to get all buffers. Return to the pool + // whatever we had managed to get. + if buf != nil { + t.bufs.free(buf) + } + return false + } + b.Next = buf + buf = b + } + + // Copy data into allocated buffers. + nBuf := buf + var dBuf []byte + for _, data := range [][]byte{a, b} { + for len(data) > 0 { + if len(dBuf) == 0 { + dBuf = t.data[nBuf.Offset:][:nBuf.Size] + nBuf = nBuf.Next + } + n := copy(dBuf, data) + data = data[n:] + dBuf = dBuf[n:] + } + } + + // Get an id for this packet and send it out. + id := t.ids.add(buf) + if !t.q.Enqueue(id, total, bufCount, buf) { + t.ids.remove(id) + t.bufs.free(buf) + return false + } + + return true +} + +// getBuffer returns a memory region mapped to the full contents of the given +// file descriptor. +func getBuffer(fd int) ([]byte, error) { + var s syscall.Stat_t + if err := syscall.Fstat(fd, &s); err != nil { + return nil, err + } + + // Check that size doesn't overflow an int. + if s.Size > int64(^uint(0)>>1) { + return nil, syscall.EDOM + } + + return syscall.Mmap(fd, 0, int(s.Size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED|syscall.MAP_FILE) +} + +// idDescriptor is used by idManager to either point to a tx buffer (in case +// the ID is assigned) or to the next free element (if the id is not assigned). +type idDescriptor struct { + buf *queue.TxBuffer + nextFree uint64 +} + +// idManager is a manager of tx buffer identifiers. It assigns unique IDs to +// tx buffers that are added to it; the IDs can only be reused after they have +// been removed. +// +// The ID assignments are stored so that the tx buffers can be retrieved from +// the IDs previously assigned to them. +type idManager struct { + // ids is a slice containing all tx buffers. The ID is the index into + // this slice. + ids []idDescriptor + + // freeList a list of free IDs. + freeList uint64 +} + +// init initializes the id manager. +func (m *idManager) init() { + m.freeList = nilID +} + +// add assigns an ID to the given tx buffer. +func (m *idManager) add(b *queue.TxBuffer) uint64 { + if i := m.freeList; i != nilID { + // There is an id available in the free list, just use it. + m.ids[i].buf = b + m.freeList = m.ids[i].nextFree + return i + } + + // We need to expand the id descriptor. + m.ids = append(m.ids, idDescriptor{buf: b}) + return uint64(len(m.ids) - 1) +} + +// remove retrieves the tx buffer associated with the given ID, and removes the +// ID from the assigned table so that it can be reused in the future. +func (m *idManager) remove(i uint64) *queue.TxBuffer { + if i >= uint64(len(m.ids)) { + return nil + } + + desc := &m.ids[i] + b := desc.buf + if b == nil { + // The provided id is not currently assigned. + return nil + } + + desc.buf = nil + desc.nextFree = m.freeList + m.freeList = i + + return b +} + +// bufferManager manages a buffer region broken up into smaller, equally sized +// buffers. Smaller buffers can be allocated and freed. +type bufferManager struct { + freeList *queue.TxBuffer + curOffset uint64 + limit uint64 + entrySize uint32 +} + +// init initializes the buffer manager. +func (b *bufferManager) init(initialOffset, size, entrySize int) { + b.freeList = nil + b.curOffset = uint64(initialOffset) + b.limit = uint64(initialOffset + size/entrySize*entrySize) + b.entrySize = uint32(entrySize) +} + +// alloc allocates a buffer from the manager, if one is available. +func (b *bufferManager) alloc() *queue.TxBuffer { + if b.freeList != nil { + // There is a descriptor ready for reuse in the free list. + d := b.freeList + b.freeList = d.Next + d.Next = nil + return d + } + + if b.curOffset < b.limit { + // There is room available in the never-used range, so create + // a new descriptor for it. + d := &queue.TxBuffer{ + Offset: b.curOffset, + Size: b.entrySize, + } + b.curOffset += uint64(b.entrySize) + return d + } + + return nil +} + +// free returns all buffers in the list to the buffer manager so that they can +// be reused. +func (b *bufferManager) free(d *queue.TxBuffer) { + // Find the last buffer in the list. + last := d + for last.Next != nil { + last = last.Next + } + + // Push list onto free list. + last.Next = b.freeList + b.freeList = d +} diff --git a/pkg/tcpip/link/tun/tun_state_autogen.go b/pkg/tcpip/link/tun/tun_state_autogen.go new file mode 100755 index 000000000..7ded170f6 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package tun + diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go new file mode 100755 index 000000000..09ca9b527 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_unsafe.go @@ -0,0 +1,63 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package tun contains methods to open TAP and TUN devices. +package tun + +import ( + "syscall" + "unsafe" +) + +// Open opens the specified TUN device, sets it to non-blocking mode, and +// returns its file descriptor. +func Open(name string) (int, error) { + return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI) +} + +// OpenTAP opens the specified TAP device, sets it to non-blocking mode, and +// returns its file descriptor. +func OpenTAP(name string) (int, error) { + return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI) +} + +func open(name string, flags uint16) (int, error) { + fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0) + if err != nil { + return -1, err + } + + var ifr struct { + name [16]byte + flags uint16 + _ [22]byte + } + + copy(ifr.name[:], name) + ifr.flags = flags + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) + if errno != 0 { + syscall.Close(fd) + return -1, errno + } + + if err = syscall.SetNonblock(fd, true); err != nil { + syscall.Close(fd) + return -1, err + } + + return fd, nil +} diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go new file mode 100755 index 000000000..a8de38979 --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable.go @@ -0,0 +1,149 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package waitable provides the implementation of data-link layer endpoints +// that wrap other endpoints, and can wait for inflight calls to WritePacket or +// DeliverNetworkPacket to finish (and new ones to be prevented). +// +// Waitable endpoints can be used in the networking stack by calling New(eID) to +// create a new endpoint, where eID is the ID of the endpoint being wrapped, +// and then passing it as an argument to Stack.CreateNIC(). +package waitable + +import ( + "gvisor.dev/gvisor/pkg/gate" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// Endpoint is a waitable link-layer endpoint. +type Endpoint struct { + dispatchGate gate.Gate + dispatcher stack.NetworkDispatcher + + writeGate gate.Gate + lower stack.LinkEndpoint +} + +// New creates a new waitable link-layer endpoint. It wraps around another +// endpoint and allows the caller to block new write/dispatch calls and wait for +// the inflight ones to finish before returning. +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, + } +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. +// It is called by the link-layer endpoint being wrapped when a packet arrives, +// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't +// been called. +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + +// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and +// registers with the lower endpoint as its dispatcher so that "e" is called +// for inbound packets. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *Endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the +// lower endpoint. +func (e *Endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the +// request to the lower endpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just +// forwards the request to the lower endpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the +// request to the lower endpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WritePacket(r, gso, protocol, pkt) + e.writeGate.Leave() + return err +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if !e.writeGate.Enter() { + return len(pkts), nil + } + + n, err := e.lower.WritePackets(r, gso, pkts, protocol) + e.writeGate.Leave() + return n, err +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WriteRawPacket(vv) + e.writeGate.Leave() + return err +} + +// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, +// and waits for inflight ones to finish before returning. +func (e *Endpoint) WaitWrite() { + e.writeGate.Close() +} + +// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the +// actual dispatcher, and waits for inflight ones to finish before returning. +func (e *Endpoint) WaitDispatch() { + e.dispatchGate.Close() +} + +// Wait implements stack.LinkEndpoint.Wait. +func (e *Endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable_state_autogen.go b/pkg/tcpip/link/waitable/waitable_state_autogen.go new file mode 100755 index 000000000..2029f4a1b --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package waitable + diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go new file mode 100755 index 000000000..93712cd45 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -0,0 +1,349 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcpconntrack implements a TCP connection tracking object. It allows +// users with access to a segment stream to figure out when a connection is +// established, reset, and closed (and in the last case, who closed first). +package tcpconntrack + +import ( + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// Result is returned when the state of a TCB is updated in response to an +// inbound or outbound segment. +type Result int + +const ( + // ResultDrop indicates that the segment should be dropped. + ResultDrop Result = iota + + // ResultConnecting indicates that the connection remains in a + // connecting state. + ResultConnecting + + // ResultAlive indicates that the connection remains alive (connected). + ResultAlive + + // ResultReset indicates that the connection was reset. + ResultReset + + // ResultClosedByPeer indicates that the connection was gracefully + // closed, and the inbound stream was closed first. + ResultClosedByPeer + + // ResultClosedBySelf indicates that the connection was gracefully + // closed, and the outbound stream was closed first. + ResultClosedBySelf +) + +// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP +// connection and inform the caller when the connection has been closed. +type TCB struct { + inbound stream + outbound stream + + // State handlers. + handlerInbound func(*TCB, header.TCP) Result + handlerOutbound func(*TCB, header.TCP) Result + + // firstFin holds a pointer to the first stream to send a FIN. + firstFin *stream + + // state is the current state of the stream. + state Result +} + +// Init initializes the state of the TCB according to the initial SYN. +func (t *TCB) Init(initialSyn header.TCP) Result { + t.handlerInbound = synSentStateInbound + t.handlerOutbound = synSentStateOutbound + + iss := seqnum.Value(initialSyn.SequenceNumber()) + t.outbound.una = iss + t.outbound.nxt = iss.Add(logicalLen(initialSyn)) + t.outbound.end = t.outbound.nxt + + // Even though "end" is a sequence number, we don't know the initial + // receive sequence number yet, so we store the window size until we get + // a SYN from the peer. + t.inbound.una = 0 + t.inbound.nxt = 0 + t.inbound.end = seqnum.Value(initialSyn.WindowSize()) + t.state = ResultConnecting + return t.state +} + +// UpdateStateInbound updates the state of the TCB based on the supplied inbound +// segment. +func (t *TCB) UpdateStateInbound(tcp header.TCP) Result { + st := t.handlerInbound(t, tcp) + if st != ResultDrop { + t.state = st + } + return st +} + +// UpdateStateOutbound updates the state of the TCB based on the supplied +// outbound segment. +func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { + st := t.handlerOutbound(t, tcp) + if st != ResultDrop { + t.state = st + } + return st +} + +// IsAlive returns true as long as the connection is established(Alive) +// or connecting state. +func (t *TCB) IsAlive() bool { + return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed()) +} + +// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream. +func (t *TCB) OutboundSendSequenceNumber() seqnum.Value { + return t.outbound.nxt +} + +// InboundSendSequenceNumber returns the snd.NXT for the inbound stream. +func (t *TCB) InboundSendSequenceNumber() seqnum.Value { + return t.inbound.nxt +} + +// adapResult modifies the supplied "Result" according to the state of the TCB; +// if r is anything other than "Alive", or if one of the streams isn't closed +// yet, it is returned unmodified. Otherwise it's converted to either +// ClosedBySelf or ClosedByPeer depending on which stream was closed first. +func (t *TCB) adaptResult(r Result) Result { + // Check the unmodified case. + if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() { + return r + } + + // Find out which was closed first. + if t.firstFin == &t.outbound { + return ResultClosedBySelf + } + + return ResultClosedByPeer +} + +// synSentStateInbound is the state handler for inbound segments when the +// connection is in SYN-SENT state. +func synSentStateInbound(t *TCB, tcp header.TCP) Result { + flags := tcp.Flags() + ackPresent := flags&header.TCPFlagAck != 0 + ack := seqnum.Value(tcp.AckNumber()) + + // Ignore segment if ack is present but not acceptable. + if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) { + return ResultConnecting + } + + // If reset is specified, we will let the packet through no matter what + // but we will also destroy the connection if the ACK is present (and + // implicitly acceptable). + if flags&header.TCPFlagRst != 0 { + if ackPresent { + t.inbound.rstSeen = true + return ResultReset + } + return ResultConnecting + } + + // Ignore segment if SYN is not set. + if flags&header.TCPFlagSyn == 0 { + return ResultConnecting + } + + // Update state informed by this SYN. + irs := seqnum.Value(tcp.SequenceNumber()) + t.inbound.una = irs + t.inbound.nxt = irs.Add(logicalLen(tcp)) + t.inbound.end += irs + + t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize())) + + // If the ACK was set (it is acceptable), update our unacknowledgement + // tracking. + if ackPresent { + // Advance the "una" and "end" indices of the outbound stream. + if t.outbound.una.LessThan(ack) { + t.outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) { + t.outbound.end = end + } + } + + // Update handlers so that new calls will be handled by new state. + t.handlerInbound = allOtherInbound + t.handlerOutbound = allOtherOutbound + + return ResultAlive +} + +// synSentStateOutbound is the state handler for outbound segments when the +// connection is in SYN-SENT state. +func synSentStateOutbound(t *TCB, tcp header.TCP) Result { + // Drop outbound segments that aren't retransmits of the original one. + if tcp.Flags() != header.TCPFlagSyn || + tcp.SequenceNumber() != uint32(t.outbound.una) { + return ResultDrop + } + + // Update the receive window. We only remember the largest value seen. + if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end { + t.inbound.end = wnd + } + + return ResultConnecting +} + +// update updates the state of inbound and outbound streams, given the supplied +// inbound segment. For outbound segments, this same function can be called with +// swapped inbound/outbound streams. +func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result { + // Ignore segments out of the window. + s := seqnum.Value(tcp.SequenceNumber()) + if !inbound.acceptable(s, dataLen(tcp)) { + return ResultAlive + } + + flags := tcp.Flags() + if flags&header.TCPFlagRst != 0 { + inbound.rstSeen = true + return ResultReset + } + + // Ignore segments that don't have the ACK flag, and those with the SYN + // flag. + if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 { + return ResultAlive + } + + // Ignore segments that acknowledge not yet sent data. + ack := seqnum.Value(tcp.AckNumber()) + if outbound.nxt.LessThan(ack) { + return ResultAlive + } + + // Advance the "una" and "end" indices of the outbound stream. + if outbound.una.LessThan(ack) { + outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) { + outbound.end = end + } + + // Advance the "nxt" index of the inbound stream. + end := s.Add(logicalLen(tcp)) + if inbound.nxt.LessThan(end) { + inbound.nxt = end + } + + // Note the index of the FIN segment. And stash away a pointer to the + // first stream to see a FIN. + if flags&header.TCPFlagFin != 0 && !inbound.finSeen { + inbound.finSeen = true + inbound.fin = end - 1 + + if *firstFin == nil { + *firstFin = inbound + } + } + + return ResultAlive +} + +// allOtherInbound is the state handler for inbound segments in all states +// except SYN-SENT. +func allOtherInbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin)) +} + +// allOtherOutbound is the state handler for outbound segments in all states +// except SYN-SENT. +func allOtherOutbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin)) +} + +// streams holds the state of a TCP unidirectional stream. +type stream struct { + // The interval [una, end) is the allowed interval as defined by the + // receiver, i.e., anything less than una has already been acknowledged + // and anything greater than or equal to end is beyond the receiver + // window. The interval [una, nxt) is the acknowledgable range, whose + // right edge indicates the sequence number of the next byte to be sent + // by the sender, i.e., anything greater than or equal to nxt hasn't + // been sent yet. + una seqnum.Value + nxt seqnum.Value + end seqnum.Value + + // finSeen indicates if a FIN has already been sent on this stream. + finSeen bool + + // fin is the sequence number of the FIN. It is only valid after finSeen + // is set to true. + fin seqnum.Value + + // rstSeen indicates if a RST has already been sent on this stream. + rstSeen bool +} + +// acceptable determines if the segment with the given sequence number and data +// length is acceptable, i.e., if it's within the [una, end) window or, in case +// the window is zero, if it's a packet with no payload and sequence number +// equal to una. +func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + wnd := s.una.Size(s.end) + if wnd == 0 { + return segLen == 0 && segSeq == s.una + } + + // Make sure [segSeq, seqSeq+segLen) is non-empty. + if segLen == 0 { + segLen = 1 + } + + return seqnum.Overlap(s.una, wnd, segSeq, segLen) +} + +// closed determines if the stream has already been closed. This happens when +// a FIN has been set by the sender and acknowledged by the receiver. +func (s *stream) closed() bool { + return s.finSeen && s.fin.LessThan(s.una) +} + +// dataLen returns the length of the TCP segment payload. +func dataLen(tcp header.TCP) seqnum.Size { + return seqnum.Size(len(tcp) - int(tcp.DataOffset())) +} + +// logicalLen calculates the logical length of the TCP segment. +func logicalLen(tcp header.TCP) seqnum.Size { + l := dataLen(tcp) + flags := tcp.Flags() + if flags&header.TCPFlagSyn != 0 { + l++ + } + if flags&header.TCPFlagFin != 0 { + l++ + } + return l +} diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go new file mode 100755 index 000000000..f3c60c272 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package tcpconntrack + |