summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/adapters/gonet
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/adapters/gonet')
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD38
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go722
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go683
3 files changed, 0 insertions, 1443 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
deleted file mode 100644
index 78df5a0b1..000000000
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ /dev/null
@@ -1,38 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "gonet",
- srcs = ["gonet.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet",
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "gonet_test",
- size = "small",
- srcs = ["gonet_test.go"],
- embed = [":gonet"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@org_golang_x_net//nettest:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
deleted file mode 100644
index cd6ce930a..000000000
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ /dev/null
@@ -1,722 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package gonet provides a Go net package compatible wrapper for a tcpip stack.
-package gonet
-
-import (
- "context"
- "errors"
- "io"
- "net"
- "sync"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-var (
- errCanceled = errors.New("operation canceled")
- errWouldBlock = errors.New("operation would block")
-)
-
-// timeoutError is how the net package reports timeouts.
-type timeoutError struct{}
-
-func (e *timeoutError) Error() string { return "i/o timeout" }
-func (e *timeoutError) Timeout() bool { return true }
-func (e *timeoutError) Temporary() bool { return true }
-
-// A Listener is a wrapper around a tcpip endpoint that implements
-// net.Listener.
-type Listener struct {
- stack *stack.Stack
- ep tcpip.Endpoint
- wq *waiter.Queue
- cancel chan struct{}
-}
-
-// NewListener creates a new Listener.
-func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
- // Create TCP endpoint, bind it, then start listening.
- var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
- if err != nil {
- return nil, errors.New(err.String())
- }
-
- if err := ep.Bind(addr); err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "bind",
- Net: "tcp",
- Addr: fullToTCPAddr(addr),
- Err: errors.New(err.String()),
- }
- }
-
- if err := ep.Listen(10); err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "listen",
- Net: "tcp",
- Addr: fullToTCPAddr(addr),
- Err: errors.New(err.String()),
- }
- }
-
- return &Listener{
- stack: s,
- ep: ep,
- wq: &wq,
- cancel: make(chan struct{}),
- }, nil
-}
-
-// Close implements net.Listener.Close.
-func (l *Listener) Close() error {
- l.ep.Close()
- return nil
-}
-
-// Shutdown stops the HTTP server.
-func (l *Listener) Shutdown() {
- l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
- close(l.cancel) // broadcast cancellation
-}
-
-// Addr implements net.Listener.Addr.
-func (l *Listener) Addr() net.Addr {
- a, err := l.ep.GetLocalAddress()
- if err != nil {
- return nil
- }
- return fullToTCPAddr(a)
-}
-
-type deadlineTimer struct {
- // mu protects the fields below.
- mu sync.Mutex
-
- readTimer *time.Timer
- readCancelCh chan struct{}
- writeTimer *time.Timer
- writeCancelCh chan struct{}
-}
-
-func (d *deadlineTimer) init() {
- d.readCancelCh = make(chan struct{})
- d.writeCancelCh = make(chan struct{})
-}
-
-func (d *deadlineTimer) readCancel() <-chan struct{} {
- d.mu.Lock()
- c := d.readCancelCh
- d.mu.Unlock()
- return c
-}
-func (d *deadlineTimer) writeCancel() <-chan struct{} {
- d.mu.Lock()
- c := d.writeCancelCh
- d.mu.Unlock()
- return c
-}
-
-// setDeadline contains the shared logic for setting a deadline.
-//
-// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and
-// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and
-// deadlineTimer.writeTimer.
-//
-// setDeadline must only be called while holding d.mu.
-func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
- if *timer != nil && !(*timer).Stop() {
- *cancelCh = make(chan struct{})
- }
-
- // Create a new channel if we already closed it due to setting an already
- // expired time. We won't race with the timer because we already handled
- // that above.
- select {
- case <-*cancelCh:
- *cancelCh = make(chan struct{})
- default:
- }
-
- // "A zero value for t means I/O operations will not time out."
- // - net.Conn.SetDeadline
- if t.IsZero() {
- return
- }
-
- timeout := t.Sub(time.Now())
- if timeout <= 0 {
- close(*cancelCh)
- return
- }
-
- // Timer.Stop returns whether or not the AfterFunc has started, but
- // does not indicate whether or not it has completed. Make a copy of
- // the cancel channel to prevent this code from racing with the next
- // call of setDeadline replacing *cancelCh.
- ch := *cancelCh
- *timer = time.AfterFunc(timeout, func() {
- close(ch)
- })
-}
-
-// SetReadDeadline implements net.Conn.SetReadDeadline and
-// net.PacketConn.SetReadDeadline.
-func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
- d.mu.Lock()
- d.setDeadline(&d.readCancelCh, &d.readTimer, t)
- d.mu.Unlock()
- return nil
-}
-
-// SetWriteDeadline implements net.Conn.SetWriteDeadline and
-// net.PacketConn.SetWriteDeadline.
-func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
- d.mu.Lock()
- d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
- d.mu.Unlock()
- return nil
-}
-
-// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
-func (d *deadlineTimer) SetDeadline(t time.Time) error {
- d.mu.Lock()
- d.setDeadline(&d.readCancelCh, &d.readTimer, t)
- d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
- d.mu.Unlock()
- return nil
-}
-
-// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
-// interface.
-type Conn struct {
- deadlineTimer
-
- wq *waiter.Queue
- ep tcpip.Endpoint
-
- // readMu serializes reads and implicitly protects read.
- //
- // Lock ordering:
- // If both readMu and deadlineTimer.mu are to be used in a single
- // request, readMu must be acquired before deadlineTimer.mu.
- readMu sync.Mutex
-
- // read contains bytes that have been read from the endpoint,
- // but haven't yet been returned.
- read buffer.View
-}
-
-// NewConn creates a new Conn.
-func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
- c := &Conn{
- wq: wq,
- ep: ep,
- }
- c.deadlineTimer.init()
- return c
-}
-
-// Accept implements net.Conn.Accept.
-func (l *Listener) Accept() (net.Conn, error) {
- n, wq, err := l.ep.Accept()
-
- if err == tcpip.ErrWouldBlock {
- // Create wait queue entry that notifies a channel.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- l.wq.EventRegister(&waitEntry, waiter.EventIn)
- defer l.wq.EventUnregister(&waitEntry)
-
- for {
- n, wq, err = l.ep.Accept()
-
- if err != tcpip.ErrWouldBlock {
- break
- }
-
- select {
- case <-l.cancel:
- return nil, errCanceled
- case <-notifyCh:
- }
- }
- }
-
- if err != nil {
- return nil, &net.OpError{
- Op: "accept",
- Net: "tcp",
- Addr: l.Addr(),
- Err: errors.New(err.String()),
- }
- }
-
- return NewConn(wq, n), nil
-}
-
-type opErrorer interface {
- newOpError(op string, err error) *net.OpError
-}
-
-// commonRead implements the common logic between net.Conn.Read and
-// net.PacketConn.ReadFrom.
-func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
- select {
- case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
- default:
- }
-
- read, _, err := ep.Read(addr)
-
- if err == tcpip.ErrWouldBlock {
- if dontWait {
- return nil, errWouldBlock
- }
- // Create wait queue entry that notifies a channel.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventIn)
- defer wq.EventUnregister(&waitEntry)
- for {
- read, _, err = ep.Read(addr)
- if err != tcpip.ErrWouldBlock {
- break
- }
- select {
- case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
- case <-notifyCh:
- }
- }
- }
-
- if err == tcpip.ErrClosedForReceive {
- return nil, io.EOF
- }
-
- if err != nil {
- return nil, errorer.newOpError("read", errors.New(err.String()))
- }
-
- return read, nil
-}
-
-// Read implements net.Conn.Read.
-func (c *Conn) Read(b []byte) (int, error) {
- c.readMu.Lock()
- defer c.readMu.Unlock()
-
- deadline := c.readCancel()
-
- numRead := 0
- for numRead != len(b) {
- if len(c.read) == 0 {
- var err error
- c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
- if err != nil {
- if numRead != 0 {
- return numRead, nil
- }
- return numRead, err
- }
- }
- n := copy(b[numRead:], c.read)
- c.read.TrimFront(n)
- numRead += n
- if len(c.read) == 0 {
- c.read = nil
- }
- }
- return numRead, nil
-}
-
-// Write implements net.Conn.Write.
-func (c *Conn) Write(b []byte) (int, error) {
- deadline := c.writeCancel()
-
- // Check if deadlineTimer has already expired.
- select {
- case <-deadline:
- return 0, c.newOpError("write", &timeoutError{})
- default:
- }
-
- v := buffer.NewViewFromBytes(b)
-
- // We must handle two soft failure conditions simultaneously:
- // 1. Write may write nothing and return tcpip.ErrWouldBlock.
- // If this happens, we need to register for notifications if we have
- // not already and wait to try again.
- // 2. Write may write fewer than the full number of bytes and return
- // without error. In this case we need to try writing the remaining
- // bytes again. I do not need to register for notifications.
- //
- // What is more, these two soft failure conditions can be interspersed.
- // There is no guarantee that all of the condition #1s will occur before
- // all of the condition #2s or visa-versa.
- var (
- err *tcpip.Error
- nbytes int
- reg bool
- notifyCh chan struct{}
- )
- for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
- if err == tcpip.ErrWouldBlock {
- if !reg {
- // Only register once.
- reg = true
-
- // Create wait queue entry that notifies a channel.
- var waitEntry waiter.Entry
- waitEntry, notifyCh = waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&waitEntry, waiter.EventOut)
- defer c.wq.EventUnregister(&waitEntry)
- } else {
- // Don't wait immediately after registration in case more data
- // became available between when we last checked and when we setup
- // the notification.
- select {
- case <-deadline:
- return nbytes, c.newOpError("write", &timeoutError{})
- case <-notifyCh:
- }
- }
- }
-
- var n int64
- var resCh <-chan struct{}
- n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- nbytes += int(n)
- v.TrimFront(int(n))
-
- if resCh != nil {
- select {
- case <-deadline:
- return nbytes, c.newOpError("write", &timeoutError{})
- case <-resCh:
- }
-
- n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- nbytes += int(n)
- v.TrimFront(int(n))
- }
- }
-
- if err == nil {
- return nbytes, nil
- }
-
- return nbytes, c.newOpError("write", errors.New(err.String()))
-}
-
-// Close implements net.Conn.Close.
-func (c *Conn) Close() error {
- c.ep.Close()
- return nil
-}
-
-// CloseRead shuts down the reading side of the TCP connection. Most callers
-// should just use Close.
-//
-// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
-func (c *Conn) CloseRead() error {
- if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
- return c.newOpError("close", errors.New(terr.String()))
- }
- return nil
-}
-
-// CloseWrite shuts down the writing side of the TCP connection. Most callers
-// should just use Close.
-//
-// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
-func (c *Conn) CloseWrite() error {
- if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
- return c.newOpError("close", errors.New(terr.String()))
- }
- return nil
-}
-
-// LocalAddr implements net.Conn.LocalAddr.
-func (c *Conn) LocalAddr() net.Addr {
- a, err := c.ep.GetLocalAddress()
- if err != nil {
- return nil
- }
- return fullToTCPAddr(a)
-}
-
-// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *Conn) RemoteAddr() net.Addr {
- a, err := c.ep.GetRemoteAddress()
- if err != nil {
- return nil
- }
- return fullToTCPAddr(a)
-}
-
-func (c *Conn) newOpError(op string, err error) *net.OpError {
- return &net.OpError{
- Op: op,
- Net: "tcp",
- Source: c.LocalAddr(),
- Addr: c.RemoteAddr(),
- Err: err,
- }
-}
-
-func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr {
- return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
-}
-
-func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
- return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
-}
-
-// DialTCP creates a new TCP Conn connected to the specified address.
-func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
- return DialContextTCP(context.Background(), s, addr, network)
-}
-
-// DialContextTCP creates a new TCP Conn connected to the specified address
-// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
- // Create TCP endpoint, then connect.
- var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
- if err != nil {
- return nil, errors.New(err.String())
- }
-
- // Create wait queue entry that notifies a channel.
- //
- // We do this unconditionally as Connect will always return an error.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventOut)
- defer wq.EventUnregister(&waitEntry)
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- }
-
- err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
- select {
- case <-ctx.Done():
- ep.Close()
- return nil, ctx.Err()
- case <-notifyCh:
- }
-
- err = ep.GetSockOpt(tcpip.ErrorOption{})
- }
- if err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "connect",
- Net: "tcp",
- Addr: fullToTCPAddr(addr),
- Err: errors.New(err.String()),
- }
- }
-
- return NewConn(&wq, ep), nil
-}
-
-// A PacketConn is a wrapper around a tcpip endpoint that implements
-// net.PacketConn.
-type PacketConn struct {
- deadlineTimer
-
- stack *stack.Stack
- ep tcpip.Endpoint
- wq *waiter.Queue
-}
-
-// DialUDP creates a new PacketConn.
-//
-// If laddr is nil, a local address is automatically chosen.
-//
-// If raddr is nil, the PacketConn is left unconnected.
-func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
- if err != nil {
- return nil, errors.New(err.String())
- }
-
- if laddr != nil {
- if err := ep.Bind(*laddr); err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "bind",
- Net: "udp",
- Addr: fullToUDPAddr(*laddr),
- Err: errors.New(err.String()),
- }
- }
- }
-
- c := PacketConn{
- stack: s,
- ep: ep,
- wq: &wq,
- }
- c.deadlineTimer.init()
-
- if raddr != nil {
- if err := c.ep.Connect(*raddr); err != nil {
- c.ep.Close()
- return nil, &net.OpError{
- Op: "connect",
- Net: "udp",
- Addr: fullToUDPAddr(*raddr),
- Err: errors.New(err.String()),
- }
- }
- }
-
- return &c, nil
-}
-
-func (c *PacketConn) newOpError(op string, err error) *net.OpError {
- return c.newRemoteOpError(op, nil, err)
-}
-
-func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
- return &net.OpError{
- Op: op,
- Net: "udp",
- Source: c.LocalAddr(),
- Addr: remote,
- Err: err,
- }
-}
-
-// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *PacketConn) RemoteAddr() net.Addr {
- a, err := c.ep.GetRemoteAddress()
- if err != nil {
- return nil
- }
- return fullToTCPAddr(a)
-}
-
-// Read implements net.Conn.Read
-func (c *PacketConn) Read(b []byte) (int, error) {
- bytesRead, _, err := c.ReadFrom(b)
- return bytesRead, err
-}
-
-// ReadFrom implements net.PacketConn.ReadFrom.
-func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
- deadline := c.readCancel()
-
- var addr tcpip.FullAddress
- read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
- if err != nil {
- return 0, nil, err
- }
-
- return copy(b, read), fullToUDPAddr(addr), nil
-}
-
-func (c *PacketConn) Write(b []byte) (int, error) {
- return c.WriteTo(b, nil)
-}
-
-// WriteTo implements net.PacketConn.WriteTo.
-func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
- deadline := c.writeCancel()
-
- // Check if deadline has already expired.
- select {
- case <-deadline:
- return 0, c.newRemoteOpError("write", addr, &timeoutError{})
- default:
- }
-
- // If we're being called by Write, there is no addr
- wopts := tcpip.WriteOptions{}
- if addr != nil {
- ua := addr.(*net.UDPAddr)
- wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
- }
-
- v := buffer.NewView(len(b))
- copy(v, b)
-
- n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
- if resCh != nil {
- select {
- case <-deadline:
- return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
- case <-resCh:
- }
-
- n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
- }
-
- if err == tcpip.ErrWouldBlock {
- // Create wait queue entry that notifies a channel.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&waitEntry, waiter.EventOut)
- defer c.wq.EventUnregister(&waitEntry)
- for {
- select {
- case <-deadline:
- return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
- case <-notifyCh:
- }
-
- n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
- if err != tcpip.ErrWouldBlock {
- break
- }
- }
- }
-
- if err == nil {
- return int(n), nil
- }
-
- return int(n), c.newRemoteOpError("write", addr, errors.New(err.String()))
-}
-
-// Close implements net.PacketConn.Close.
-func (c *PacketConn) Close() error {
- c.ep.Close()
- return nil
-}
-
-// LocalAddr implements net.PacketConn.LocalAddr.
-func (c *PacketConn) LocalAddr() net.Addr {
- a, err := c.ep.GetLocalAddress()
- if err != nil {
- return nil
- }
- return fullToUDPAddr(a)
-}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
deleted file mode 100644
index ee077ae83..000000000
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ /dev/null
@@ -1,683 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gonet
-
-import (
- "context"
- "fmt"
- "io"
- "net"
- "reflect"
- "strings"
- "testing"
- "time"
-
- "golang.org/x/net/nettest"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- NICID = 1
-)
-
-func TestTimeouts(t *testing.T) {
- nc := NewConn(nil, nil)
- dlfs := []struct {
- name string
- f func(time.Time) error
- }{
- {"SetDeadline", nc.SetDeadline},
- {"SetReadDeadline", nc.SetReadDeadline},
- {"SetWriteDeadline", nc.SetWriteDeadline},
- }
-
- for _, dlf := range dlfs {
- if err := dlf.f(time.Time{}); err != nil {
- t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil)
- }
- }
-}
-
-func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
- // Create the stack and add a NIC.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
- })
-
- if err := s.CreateNIC(NICID, loopback.New()); err != nil {
- return nil, err
- }
-
- // Add default route.
- s.SetRouteTable([]tcpip.Route{
- // IPv4
- {
- Destination: header.IPv4EmptySubnet,
- NIC: NICID,
- },
-
- // IPv6
- {
- Destination: header.IPv6EmptySubnet,
- NIC: NICID,
- },
- })
-
- return s, nil
-}
-
-type testConnection struct {
- wq *waiter.Queue
- e *waiter.Entry
- ch chan struct{}
- ep tcpip.Endpoint
-}
-
-func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
- wq := &waiter.Queue{}
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
-
- entry, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&entry, waiter.EventOut)
-
- err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
- <-ch
- err = ep.GetSockOpt(tcpip.ErrorOption{})
- }
- if err != nil {
- return nil, err
- }
-
- wq.EventUnregister(&entry)
- wq.EventRegister(&entry, waiter.EventIn)
-
- return &testConnection{wq, &entry, ch, ep}, nil
-}
-
-func (c *testConnection) close() {
- c.wq.EventUnregister(c.e)
- c.ep.Close()
-}
-
-// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock.
-func TestCloseReader(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
-
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
- if e != nil {
- t.Fatalf("NewListener() = %v", e)
- }
- done := make(chan struct{})
- go func() {
- defer close(done)
- c, err := l.Accept()
- if err != nil {
- t.Fatalf("l.Accept() = %v", err)
- }
-
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.Close()
- })
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- if n != 0 || err != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err)
- }
- }()
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(5 * time.Second):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
-// using tcp.Forwarder.
-func TestCloseReaderWithForwarder(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- done := make(chan struct{})
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- defer close(done)
-
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewConn(&wq, ep)
-
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.Close()
- })
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if n != 0 || e != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e)
- }
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(5 * time.Second):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-func TestCloseRead(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
- })
-
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- tc, terr := connect(s, addr)
- if terr != nil {
- t.Fatalf("connect() = %v", terr)
- }
- c := NewConn(tc.wq, tc.ep)
-
- if err := c.CloseRead(); err != nil {
- t.Errorf("c.CloseRead() = %v", err)
- }
-
- buf := make([]byte, 256)
- if n, err := c.Read(buf); err != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err)
- }
-
- if n, err := c.Write([]byte("abc123")); n != 6 || err != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err)
- }
-}
-
-func TestCloseWrite(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewConn(&wq, ep)
-
- n, e := c.Read(make([]byte, 256))
- if n != 0 || e != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); n != 6 || e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
- })
-
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- tc, terr := connect(s, addr)
- if terr != nil {
- t.Fatalf("connect() = %v", terr)
- }
- c := NewConn(tc.wq, tc.ep)
-
- if err := c.CloseWrite(); err != nil {
- t.Errorf("c.CloseWrite() = %v", err)
- }
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- if err != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err)
- }
-
- n, err = c.Write([]byte("abc123"))
- got, ok := err.(*net.OpError)
- want := "endpoint is closed for send"
- if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) {
- t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want)
- }
-}
-
-func TestUDPForwarder(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
-
- ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
- ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
- addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
-
- done := make(chan struct{})
- fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
- defer close(done)
-
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
-
- c := NewConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil {
- t.Errorf("c.Read() = %v", e)
- }
-
- if _, e := c.Write(buf[:n]); e != nil {
- t.Errorf("c.Write() = %v", e)
- }
- })
- s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
-
- c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- sent := "abc123"
- sendAddr := fullToUDPAddr(addr1)
- if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
- t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
- }
-
- buf := make([]byte, 256)
- n, recvAddr, err := c2.ReadFrom(buf)
- if err != nil || recvAddr.String() != sendAddr.String() {
- t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil)
- }
-}
-
-// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
-func TestDeadlineChange(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
-
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
- if e != nil {
- t.Fatalf("NewListener() = %v", e)
- }
- done := make(chan struct{})
- go func() {
- defer close(done)
- c, err := l.Accept()
- if err != nil {
- t.Fatalf("l.Accept() = %v", err)
- }
-
- c.SetDeadline(time.Now().Add(time.Minute))
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.SetDeadline(time.Now().Add(time.Millisecond * 10))
- })
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- got, ok := err.(*net.OpError)
- want := "i/o timeout"
- if n != 0 || !ok || got.Err == nil || got.Err.Error() != want {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want)
- }
- }()
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(time.Millisecond * 500):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-func TestPacketConnTransfer(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
-
- ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
- ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
- addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
-
- c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 4):", err)
- }
- c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- sent := "abc123"
- sendAddr := fullToUDPAddr(addr2)
- if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
- t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
- }
- recv := make([]byte, len(sent))
- n, recvAddr, err := c2.ReadFrom(recv)
- if err != nil || n != len(recv) {
- t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil)
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("got recv = %q, want = %q", recv, sent)
- }
-
- if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) {
- t.Errorf("got recvAddr = %v, want = %v", recvAddr, want)
- }
-
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
-}
-
-func TestConnectedPacketConnTransfer(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
-
- c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 4):", err)
- }
- c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- sent := "abc123"
- if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
- t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
- }
- recv := make([]byte, len(sent))
- n, err := c1.Read(recv)
- if err != nil || n != len(recv) {
- t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("got recv = %q, want = %q", recv, sent)
- }
-
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
-}
-
-func makePipe() (c1, c2 net.Conn, stop func(), err error) {
- s, e := newLoopbackStack()
- if e != nil {
- return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e)
- }
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
-
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
- }
-
- c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
- }
-
- c2, err = l.Accept()
- if err != nil {
- l.Close()
- c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
- }
-
- stop = func() {
- c1.Close()
- c2.Close()
- }
-
- if err := l.Close(); err != nil {
- stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
- }
-
- return c1, c2, stop, nil
-}
-
-func TestTCPConnTransfer(t *testing.T) {
- c1, c2, _, err := makePipe()
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
- }()
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- const sent = "abc123"
-
- tests := []struct {
- name string
- c1 net.Conn
- c2 net.Conn
- }{
- {"connected to accepted", c1, c2},
- {"accepted to connected", c2, c1},
- }
-
- for _, test := range tests {
- if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) {
- t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil)
- continue
- }
-
- recv := make([]byte, len(sent))
- n, err := test.c2.Read(recv)
- if err != nil || n != len(recv) {
- t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil)
- continue
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent)
- }
- }
-}
-
-func TestTCPDialError(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
-
- _, err := DialTCP(s, addr, ipv4.ProtocolNumber)
- got, ok := err.(*net.OpError)
- want := tcpip.ErrNoRoute
- if !ok || got.Err.Error() != want.String() {
- t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
- }
-}
-
-func TestDialContextTCPCanceled(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- ctx := context.Background()
- ctx, cancel := context.WithCancel(ctx)
- cancel()
-
- if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled {
- t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled)
- }
-}
-
-func TestDialContextTCPTimeout(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- time.Sleep(time.Second)
- r.Complete(true)
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- ctx := context.Background()
- ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond))
- defer cancel()
-
- if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded {
- t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded)
- }
-}
-
-func TestNetTest(t *testing.T) {
- nettest.TestConn(t, makePipe)
-}