summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/adapters/gonet
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/adapters/gonet')
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD36
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go605
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go424
3 files changed, 1065 insertions, 0 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
new file mode 100644
index 000000000..69cfc84ab
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -0,0 +1,36 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "gonet",
+ srcs = ["gonet.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/adapters/gonet",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "gonet_test",
+ size = "small",
+ srcs = ["gonet_test.go"],
+ embed = [":gonet"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@org_golang_x_net//nettest:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
new file mode 100644
index 000000000..96a2d670d
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -0,0 +1,605 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package gonet provides a Go net package compatible wrapper for a tcpip stack.
+package gonet
+
+import (
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+var errCanceled = errors.New("operation canceled")
+
+// timeoutError is how the net package reports timeouts.
+type timeoutError struct{}
+
+func (e *timeoutError) Error() string { return "i/o timeout" }
+func (e *timeoutError) Timeout() bool { return true }
+func (e *timeoutError) Temporary() bool { return true }
+
+// A Listener is a wrapper around a tcpip endpoint that implements
+// net.Listener.
+type Listener struct {
+ stack *stack.Stack
+ ep tcpip.Endpoint
+ wq *waiter.Queue
+ cancel chan struct{}
+}
+
+// NewListener creates a new Listener.
+func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
+ // Create TCP endpoint, bind it, then start listening.
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
+ if err != nil {
+ return nil, errors.New(err.String())
+ }
+
+ if err := ep.Bind(addr, nil); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "tcp",
+ Addr: fullToTCPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ if err := ep.Listen(10); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "listen",
+ Net: "tcp",
+ Addr: fullToTCPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ return &Listener{
+ stack: s,
+ ep: ep,
+ wq: &wq,
+ cancel: make(chan struct{}),
+ }, nil
+}
+
+// Close implements net.Listener.Close.
+func (l *Listener) Close() error {
+ l.ep.Close()
+ return nil
+}
+
+// Shutdown stops the HTTP server.
+func (l *Listener) Shutdown() {
+ l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ close(l.cancel) // broadcast cancellation
+}
+
+// Addr implements net.Listener.Addr.
+func (l *Listener) Addr() net.Addr {
+ a, err := l.ep.GetLocalAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToTCPAddr(a)
+}
+
+type deadlineTimer struct {
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ readTimer *time.Timer
+ readCancelCh chan struct{}
+ writeTimer *time.Timer
+ writeCancelCh chan struct{}
+}
+
+func (d *deadlineTimer) init() {
+ d.readCancelCh = make(chan struct{})
+ d.writeCancelCh = make(chan struct{})
+}
+
+func (d *deadlineTimer) readCancel() <-chan struct{} {
+ d.mu.Lock()
+ c := d.readCancelCh
+ d.mu.Unlock()
+ return c
+}
+func (d *deadlineTimer) writeCancel() <-chan struct{} {
+ d.mu.Lock()
+ c := d.writeCancelCh
+ d.mu.Unlock()
+ return c
+}
+
+// setDeadline contains the shared logic for setting a deadline.
+//
+// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and
+// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and
+// deadlineTimer.writeTimer.
+//
+// setDeadline must only be called while holding d.mu.
+func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
+ if *timer != nil && !(*timer).Stop() {
+ *cancelCh = make(chan struct{})
+ }
+
+ // Create a new channel if we already closed it due to setting an already
+ // expired time. We won't race with the timer because we already handled
+ // that above.
+ select {
+ case <-*cancelCh:
+ *cancelCh = make(chan struct{})
+ default:
+ }
+
+ // "A zero value for t means I/O operations will not time out."
+ // - net.Conn.SetDeadline
+ if t.IsZero() {
+ return
+ }
+
+ timeout := t.Sub(time.Now())
+ if timeout <= 0 {
+ close(*cancelCh)
+ return
+ }
+
+ // Timer.Stop returns whether or not the AfterFunc has started, but
+ // does not indicate whether or not it has completed. Make a copy of
+ // the cancel channel to prevent this code from racing with the next
+ // call of setDeadline replacing *cancelCh.
+ ch := *cancelCh
+ *timer = time.AfterFunc(timeout, func() {
+ close(ch)
+ })
+}
+
+// SetReadDeadline implements net.Conn.SetReadDeadline and
+// net.PacketConn.SetReadDeadline.
+func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
+ d.mu.Lock()
+ d.setDeadline(&d.readCancelCh, &d.readTimer, t)
+ d.mu.Unlock()
+ return nil
+}
+
+// SetWriteDeadline implements net.Conn.SetWriteDeadline and
+// net.PacketConn.SetWriteDeadline.
+func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
+ d.mu.Lock()
+ d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
+ d.mu.Unlock()
+ return nil
+}
+
+// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
+func (d *deadlineTimer) SetDeadline(t time.Time) error {
+ d.mu.Lock()
+ d.setDeadline(&d.readCancelCh, &d.readTimer, t)
+ d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
+ d.mu.Unlock()
+ return nil
+}
+
+// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
+// interface.
+type Conn struct {
+ deadlineTimer
+
+ wq *waiter.Queue
+ ep tcpip.Endpoint
+
+ // readMu serializes reads and implicitly protects read.
+ //
+ // Lock ordering:
+ // If both readMu and deadlineTimer.mu are to be used in a single
+ // request, readMu must be acquired before deadlineTimer.mu.
+ readMu sync.Mutex
+
+ // read contains bytes that have been read from the endpoint,
+ // but haven't yet been returned.
+ read buffer.View
+}
+
+// NewConn creates a new Conn.
+func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
+ c := &Conn{
+ wq: wq,
+ ep: ep,
+ }
+ c.deadlineTimer.init()
+ return c
+}
+
+// Accept implements net.Conn.Accept.
+func (l *Listener) Accept() (net.Conn, error) {
+ n, wq, err := l.ep.Accept()
+
+ if err == tcpip.ErrWouldBlock {
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ l.wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer l.wq.EventUnregister(&waitEntry)
+
+ for {
+ n, wq, err = l.ep.Accept()
+
+ if err != tcpip.ErrWouldBlock {
+ break
+ }
+
+ select {
+ case <-l.cancel:
+ return nil, errCanceled
+ case <-notifyCh:
+ }
+ }
+ }
+
+ if err != nil {
+ return nil, &net.OpError{
+ Op: "accept",
+ Net: "tcp",
+ Addr: l.Addr(),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ return NewConn(wq, n), nil
+}
+
+type opErrorer interface {
+ newOpError(op string, err error) *net.OpError
+}
+
+// commonRead implements the common logic between net.Conn.Read and
+// net.PacketConn.ReadFrom.
+func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) ([]byte, error) {
+ read, err := ep.Read(addr)
+
+ if err == tcpip.ErrWouldBlock {
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer wq.EventUnregister(&waitEntry)
+ for {
+ read, err = ep.Read(addr)
+ if err != tcpip.ErrWouldBlock {
+ break
+ }
+ select {
+ case <-deadline:
+ return nil, errorer.newOpError("read", &timeoutError{})
+ case <-notifyCh:
+ }
+ }
+ }
+
+ if err == tcpip.ErrClosedForReceive {
+ return nil, io.EOF
+ }
+
+ if err != nil {
+ return nil, errorer.newOpError("read", errors.New(err.String()))
+ }
+
+ return read, nil
+}
+
+// Read implements net.Conn.Read.
+func (c *Conn) Read(b []byte) (int, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
+
+ deadline := c.readCancel()
+
+ // Check if deadline has already expired.
+ select {
+ case <-deadline:
+ return 0, c.newOpError("read", &timeoutError{})
+ default:
+ }
+
+ if len(c.read) == 0 {
+ var err error
+ c.read, err = commonRead(c.ep, c.wq, deadline, nil, c)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ n := copy(b, c.read)
+ c.read.TrimFront(n)
+ if len(c.read) == 0 {
+ c.read = nil
+ }
+ return n, nil
+}
+
+// Write implements net.Conn.Write.
+func (c *Conn) Write(b []byte) (int, error) {
+ deadline := c.writeCancel()
+
+ // Check if deadlineTimer has already expired.
+ select {
+ case <-deadline:
+ return 0, c.newOpError("write", &timeoutError{})
+ default:
+ }
+
+ v := buffer.NewView(len(b))
+ copy(v, b)
+
+ // We must handle two soft failure conditions simultaneously:
+ // 1. Write may write nothing and return tcpip.ErrWouldBlock.
+ // If this happens, we need to register for notifications if we have
+ // not already and wait to try again.
+ // 2. Write may write fewer than the full number of bytes and return
+ // without error. In this case we need to try writing the remaining
+ // bytes again. I do not need to register for notifications.
+ //
+ // What is more, these two soft failure conditions can be interspersed.
+ // There is no guarantee that all of the condition #1s will occur before
+ // all of the condition #2s or visa-versa.
+ var (
+ err *tcpip.Error
+ nbytes int
+ reg bool
+ notifyCh chan struct{}
+ )
+ for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
+ if err == tcpip.ErrWouldBlock {
+ if !reg {
+ // Only register once.
+ reg = true
+
+ // Create wait queue entry that notifies a channel.
+ var waitEntry waiter.Entry
+ waitEntry, notifyCh = waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.wq.EventUnregister(&waitEntry)
+ } else {
+ // Don't wait immediately after registration in case more data
+ // became available between when we last checked and when we setup
+ // the notification.
+ select {
+ case <-deadline:
+ return nbytes, c.newOpError("write", &timeoutError{})
+ case <-notifyCh:
+ }
+ }
+ }
+
+ var n uintptr
+ n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ nbytes += int(n)
+ v.TrimFront(int(n))
+ }
+
+ if err == nil {
+ return nbytes, nil
+ }
+
+ return nbytes, c.newOpError("write", errors.New(err.String()))
+}
+
+// Close implements net.Conn.Close.
+func (c *Conn) Close() error {
+ c.ep.Close()
+ return nil
+}
+
+// LocalAddr implements net.Conn.LocalAddr.
+func (c *Conn) LocalAddr() net.Addr {
+ a, err := c.ep.GetLocalAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToTCPAddr(a)
+}
+
+// RemoteAddr implements net.Conn.RemoteAddr.
+func (c *Conn) RemoteAddr() net.Addr {
+ a, err := c.ep.GetRemoteAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToTCPAddr(a)
+}
+
+func (c *Conn) newOpError(op string, err error) *net.OpError {
+ return &net.OpError{
+ Op: op,
+ Net: "tcp",
+ Source: c.LocalAddr(),
+ Addr: c.RemoteAddr(),
+ Err: err,
+ }
+}
+
+func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr {
+ return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
+}
+
+func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
+ return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
+}
+
+// DialTCP creates a new TCP Conn connected to the specified address.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+ // Create TCP endpoint, then connect.
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
+ if err != nil {
+ return nil, errors.New(err.String())
+ }
+
+ // Create wait queue entry that notifies a channel.
+ //
+ // We do this unconditionally as Connect will always return an error.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer wq.EventUnregister(&waitEntry)
+
+ err = ep.Connect(addr)
+ if err == tcpip.ErrConnectStarted {
+ <-notifyCh
+ err = ep.GetSockOpt(tcpip.ErrorOption{})
+ }
+ if err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "tcp",
+ Addr: fullToTCPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ return NewConn(&wq, ep), nil
+}
+
+// A PacketConn is a wrapper around a tcpip endpoint that implements
+// net.PacketConn.
+type PacketConn struct {
+ deadlineTimer
+
+ stack *stack.Stack
+ ep tcpip.Endpoint
+ wq *waiter.Queue
+}
+
+// NewPacketConn creates a new PacketConn.
+func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
+ // Create UDP endpoint and bind it.
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
+ if err != nil {
+ return nil, errors.New(err.String())
+ }
+
+ if err := ep.Bind(addr, nil); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "udp",
+ Addr: fullToUDPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
+
+ c := &PacketConn{
+ stack: s,
+ ep: ep,
+ wq: &wq,
+ }
+ c.deadlineTimer.init()
+ return c, nil
+}
+
+func (c *PacketConn) newOpError(op string, err error) *net.OpError {
+ return c.newRemoteOpError(op, nil, err)
+}
+
+func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+ return &net.OpError{
+ Op: op,
+ Net: "udp",
+ Source: c.LocalAddr(),
+ Addr: remote,
+ Err: err,
+ }
+}
+
+// ReadFrom implements net.PacketConn.ReadFrom.
+func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+ deadline := c.readCancel()
+
+ // Check if deadline has already expired.
+ select {
+ case <-deadline:
+ return 0, nil, c.newOpError("read", &timeoutError{})
+ default:
+ }
+
+ var addr tcpip.FullAddress
+ read, err := commonRead(c.ep, c.wq, deadline, &addr, c)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return copy(b, read), fullToUDPAddr(addr), nil
+}
+
+// WriteTo implements net.PacketConn.WriteTo.
+func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+ deadline := c.writeCancel()
+
+ // Check if deadline has already expired.
+ select {
+ case <-deadline:
+ return 0, c.newRemoteOpError("write", addr, &timeoutError{})
+ default:
+ }
+
+ ua := addr.(*net.UDPAddr)
+ fullAddr := tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
+
+ v := buffer.NewView(len(b))
+ copy(v, b)
+
+ wopts := tcpip.WriteOptions{To: &fullAddr}
+ n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
+
+ if err == tcpip.ErrWouldBlock {
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.wq.EventUnregister(&waitEntry)
+ for {
+ n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ if err != tcpip.ErrWouldBlock {
+ break
+ }
+ select {
+ case <-deadline:
+ return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
+ case <-notifyCh:
+ }
+ }
+ }
+
+ if err == nil {
+ return int(n), nil
+ }
+
+ return int(n), c.newRemoteOpError("write", addr, errors.New(err.String()))
+}
+
+// Close implements net.PacketConn.Close.
+func (c *PacketConn) Close() error {
+ c.ep.Close()
+ return nil
+}
+
+// LocalAddr implements net.PacketConn.LocalAddr.
+func (c *PacketConn) LocalAddr() net.Addr {
+ a, err := c.ep.GetLocalAddress()
+ if err != nil {
+ return nil
+ }
+ return fullToUDPAddr(a)
+}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
new file mode 100644
index 000000000..2f86469eb
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -0,0 +1,424 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonet
+
+import (
+ "fmt"
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/net/nettest"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ NICID = 1
+)
+
+func TestTimeouts(t *testing.T) {
+ nc := NewConn(nil, nil)
+ dlfs := []struct {
+ name string
+ f func(time.Time) error
+ }{
+ {"SetDeadline", nc.SetDeadline},
+ {"SetReadDeadline", nc.SetReadDeadline},
+ {"SetWriteDeadline", nc.SetWriteDeadline},
+ }
+
+ for _, dlf := range dlfs {
+ if err := dlf.f(time.Time{}); err != nil {
+ t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil)
+ }
+ }
+}
+
+func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
+ // Create the stack and add a NIC.
+ s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName})
+
+ if err := s.CreateNIC(NICID, loopback.New()); err != nil {
+ return nil, err
+ }
+
+ // Add default route.
+ s.SetRouteTable([]tcpip.Route{
+ // IPv4
+ {
+ Destination: tcpip.Address(strings.Repeat("\x00", 4)),
+ Mask: tcpip.Address(strings.Repeat("\x00", 4)),
+ Gateway: "",
+ NIC: NICID,
+ },
+
+ // IPv6
+ {
+ Destination: tcpip.Address(strings.Repeat("\x00", 16)),
+ Mask: tcpip.Address(strings.Repeat("\x00", 16)),
+ Gateway: "",
+ NIC: NICID,
+ },
+ })
+
+ return s, nil
+}
+
+type testConnection struct {
+ wq *waiter.Queue
+ e *waiter.Entry
+ ch chan struct{}
+ ep tcpip.Endpoint
+}
+
+func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
+ wq := &waiter.Queue{}
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+
+ entry, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&entry, waiter.EventOut)
+
+ err = ep.Connect(addr)
+ if err == tcpip.ErrConnectStarted {
+ <-ch
+ err = ep.GetSockOpt(tcpip.ErrorOption{})
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ wq.EventUnregister(&entry)
+ wq.EventRegister(&entry, waiter.EventIn)
+
+ return &testConnection{wq, &entry, ch, ep}, nil
+}
+
+func (c *testConnection) close() {
+ c.wq.EventUnregister(c.e)
+ c.ep.Close()
+}
+
+// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock.
+func TestCloseReader(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ if e != nil {
+ t.Fatalf("NewListener() = %v", e)
+ }
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatalf("l.Accept() = %v", err)
+ }
+
+ // Give c.Read() a chance to block before closing the connection.
+ time.AfterFunc(time.Millisecond*50, func() {
+ c.Close()
+ })
+
+ buf := make([]byte, 256)
+ n, err := c.Read(buf)
+ got, ok := err.(*net.OpError)
+ want := tcpip.ErrConnectionAborted
+ if n != 0 || !ok || got.Err.Error() != want.String() {
+ t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want)
+ }
+ }()
+ sender, err := connect(s, addr)
+ if err != nil {
+ t.Fatalf("connect() = %v", err)
+ }
+
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ t.Errorf("c.Read() didn't unblock")
+ }
+ sender.close()
+}
+
+// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
+// using tcp.Forwarder.
+func TestCloseReaderWithForwarder(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ done := make(chan struct{})
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ defer close(done)
+
+ var wq waiter.Queue
+ ep, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ defer ep.Close()
+ r.Complete(false)
+
+ c := NewConn(&wq, ep)
+
+ // Give c.Read() a chance to block before closing the connection.
+ time.AfterFunc(time.Millisecond*50, func() {
+ c.Close()
+ })
+
+ buf := make([]byte, 256)
+ n, e := c.Read(buf)
+ got, ok := e.(*net.OpError)
+ want := tcpip.ErrConnectionAborted
+ if n != 0 || !ok || got.Err.Error() != want.String() {
+ t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want)
+ }
+ })
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ sender, err := connect(s, addr)
+ if err != nil {
+ t.Fatalf("connect() = %v", err)
+ }
+
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ t.Errorf("c.Read() didn't unblock")
+ }
+ sender.close()
+}
+
+// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
+func TestDeadlineChange(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ if e != nil {
+ t.Fatalf("NewListener() = %v", e)
+ }
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatalf("l.Accept() = %v", err)
+ }
+
+ c.SetDeadline(time.Now().Add(time.Minute))
+ // Give c.Read() a chance to block before closing the connection.
+ time.AfterFunc(time.Millisecond*50, func() {
+ c.SetDeadline(time.Now().Add(time.Millisecond * 10))
+ })
+
+ buf := make([]byte, 256)
+ n, err := c.Read(buf)
+ got, ok := err.(*net.OpError)
+ want := "i/o timeout"
+ if n != 0 || !ok || got.Err == nil || got.Err.Error() != want {
+ t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want)
+ }
+ }()
+ sender, err := connect(s, addr)
+ if err != nil {
+ t.Fatalf("connect() = %v", err)
+ }
+
+ select {
+ case <-done:
+ case <-time.After(time.Millisecond * 500):
+ t.Errorf("c.Read() didn't unblock")
+ }
+ sender.close()
+}
+
+func TestPacketConnTransfer(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+
+ ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr1 := tcpip.FullAddress{NICID, ip1, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
+ addr2 := tcpip.FullAddress{NICID, ip2, 11311}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+
+ c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("NewPacketConn(port 4):", err)
+ }
+ c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("NewPacketConn(port 5):", err)
+ }
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ sent := "abc123"
+ sendAddr := fullToUDPAddr(addr2)
+ if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
+ t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
+ }
+ recv := make([]byte, len(sent))
+ n, recvAddr, err := c2.ReadFrom(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("got recv = %q, want = %q", recv, sent)
+ }
+
+ if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) {
+ t.Errorf("got recvAddr = %v, want = %v", recvAddr, want)
+ }
+
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+}
+
+func makePipe() (c1, c2 net.Conn, stop func(), err error) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e)
+ }
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+ l, err := NewListener(s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ }
+
+ c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ l.Close()
+ return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ }
+
+ c2, err = l.Accept()
+ if err != nil {
+ l.Close()
+ c1.Close()
+ return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ }
+
+ stop = func() {
+ c1.Close()
+ c2.Close()
+ }
+
+ if err := l.Close(); err != nil {
+ stop()
+ return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ }
+
+ return c1, c2, stop, nil
+}
+
+func TestTCPConnTransfer(t *testing.T) {
+ c1, c2, _, err := makePipe()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+ }()
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ const sent = "abc123"
+
+ tests := []struct {
+ name string
+ c1 net.Conn
+ c2 net.Conn
+ }{
+ {"connected to accepted", c1, c2},
+ {"accepted to connected", c2, c1},
+ }
+
+ for _, test := range tests {
+ if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) {
+ t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil)
+ continue
+ }
+
+ recv := make([]byte, len(sent))
+ n, err := test.c2.Read(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil)
+ continue
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent)
+ }
+ }
+}
+
+func TestTCPDialError(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+
+ _, err := DialTCP(s, addr, ipv4.ProtocolNumber)
+ got, ok := err.(*net.OpError)
+ want := tcpip.ErrNoRoute
+ if !ok || got.Err.Error() != want.String() {
+ t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
+ }
+}
+
+func TestNetTest(t *testing.T) {
+ nettest.TestConn(t, makePipe)
+}