diff options
Diffstat (limited to 'pkg/tcpip/adapters/gonet')
-rw-r--r-- | pkg/tcpip/adapters/gonet/BUILD | 36 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 605 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet_test.go | 424 |
3 files changed, 1065 insertions, 0 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD new file mode 100644 index 000000000..69cfc84ab --- /dev/null +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -0,0 +1,36 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "gonet", + srcs = ["gonet.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/adapters/gonet", + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + ], +) + +go_test( + name = "gonet_test", + size = "small", + srcs = ["gonet_test.go"], + embed = [":gonet"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@org_golang_x_net//nettest:go_default_library", + ], +) diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go new file mode 100644 index 000000000..96a2d670d --- /dev/null +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -0,0 +1,605 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package gonet provides a Go net package compatible wrapper for a tcpip stack. +package gonet + +import ( + "errors" + "io" + "net" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +var errCanceled = errors.New("operation canceled") + +// timeoutError is how the net package reports timeouts. +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +// A Listener is a wrapper around a tcpip endpoint that implements +// net.Listener. +type Listener struct { + stack *stack.Stack + ep tcpip.Endpoint + wq *waiter.Queue + cancel chan struct{} +} + +// NewListener creates a new Listener. +func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { + // Create TCP endpoint, bind it, then start listening. + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) + if err != nil { + return nil, errors.New(err.String()) + } + + if err := ep.Bind(addr, nil); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "tcp", + Addr: fullToTCPAddr(addr), + Err: errors.New(err.String()), + } + } + + if err := ep.Listen(10); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "listen", + Net: "tcp", + Addr: fullToTCPAddr(addr), + Err: errors.New(err.String()), + } + } + + return &Listener{ + stack: s, + ep: ep, + wq: &wq, + cancel: make(chan struct{}), + }, nil +} + +// Close implements net.Listener.Close. +func (l *Listener) Close() error { + l.ep.Close() + return nil +} + +// Shutdown stops the HTTP server. +func (l *Listener) Shutdown() { + l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + close(l.cancel) // broadcast cancellation +} + +// Addr implements net.Listener.Addr. +func (l *Listener) Addr() net.Addr { + a, err := l.ep.GetLocalAddress() + if err != nil { + return nil + } + return fullToTCPAddr(a) +} + +type deadlineTimer struct { + // mu protects the fields below. + mu sync.Mutex + + readTimer *time.Timer + readCancelCh chan struct{} + writeTimer *time.Timer + writeCancelCh chan struct{} +} + +func (d *deadlineTimer) init() { + d.readCancelCh = make(chan struct{}) + d.writeCancelCh = make(chan struct{}) +} + +func (d *deadlineTimer) readCancel() <-chan struct{} { + d.mu.Lock() + c := d.readCancelCh + d.mu.Unlock() + return c +} +func (d *deadlineTimer) writeCancel() <-chan struct{} { + d.mu.Lock() + c := d.writeCancelCh + d.mu.Unlock() + return c +} + +// setDeadline contains the shared logic for setting a deadline. +// +// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and +// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and +// deadlineTimer.writeTimer. +// +// setDeadline must only be called while holding d.mu. +func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { + if *timer != nil && !(*timer).Stop() { + *cancelCh = make(chan struct{}) + } + + // Create a new channel if we already closed it due to setting an already + // expired time. We won't race with the timer because we already handled + // that above. + select { + case <-*cancelCh: + *cancelCh = make(chan struct{}) + default: + } + + // "A zero value for t means I/O operations will not time out." + // - net.Conn.SetDeadline + if t.IsZero() { + return + } + + timeout := t.Sub(time.Now()) + if timeout <= 0 { + close(*cancelCh) + return + } + + // Timer.Stop returns whether or not the AfterFunc has started, but + // does not indicate whether or not it has completed. Make a copy of + // the cancel channel to prevent this code from racing with the next + // call of setDeadline replacing *cancelCh. + ch := *cancelCh + *timer = time.AfterFunc(timeout, func() { + close(ch) + }) +} + +// SetReadDeadline implements net.Conn.SetReadDeadline and +// net.PacketConn.SetReadDeadline. +func (d *deadlineTimer) SetReadDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.mu.Unlock() + return nil +} + +// SetWriteDeadline implements net.Conn.SetWriteDeadline and +// net.PacketConn.SetWriteDeadline. +func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. +func (d *deadlineTimer) SetDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn +// interface. +type Conn struct { + deadlineTimer + + wq *waiter.Queue + ep tcpip.Endpoint + + // readMu serializes reads and implicitly protects read. + // + // Lock ordering: + // If both readMu and deadlineTimer.mu are to be used in a single + // request, readMu must be acquired before deadlineTimer.mu. + readMu sync.Mutex + + // read contains bytes that have been read from the endpoint, + // but haven't yet been returned. + read buffer.View +} + +// NewConn creates a new Conn. +func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { + c := &Conn{ + wq: wq, + ep: ep, + } + c.deadlineTimer.init() + return c +} + +// Accept implements net.Conn.Accept. +func (l *Listener) Accept() (net.Conn, error) { + n, wq, err := l.ep.Accept() + + if err == tcpip.ErrWouldBlock { + // Create wait queue entry that notifies a channel. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + l.wq.EventRegister(&waitEntry, waiter.EventIn) + defer l.wq.EventUnregister(&waitEntry) + + for { + n, wq, err = l.ep.Accept() + + if err != tcpip.ErrWouldBlock { + break + } + + select { + case <-l.cancel: + return nil, errCanceled + case <-notifyCh: + } + } + } + + if err != nil { + return nil, &net.OpError{ + Op: "accept", + Net: "tcp", + Addr: l.Addr(), + Err: errors.New(err.String()), + } + } + + return NewConn(wq, n), nil +} + +type opErrorer interface { + newOpError(op string, err error) *net.OpError +} + +// commonRead implements the common logic between net.Conn.Read and +// net.PacketConn.ReadFrom. +func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) ([]byte, error) { + read, err := ep.Read(addr) + + if err == tcpip.ErrWouldBlock { + // Create wait queue entry that notifies a channel. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventIn) + defer wq.EventUnregister(&waitEntry) + for { + read, err = ep.Read(addr) + if err != tcpip.ErrWouldBlock { + break + } + select { + case <-deadline: + return nil, errorer.newOpError("read", &timeoutError{}) + case <-notifyCh: + } + } + } + + if err == tcpip.ErrClosedForReceive { + return nil, io.EOF + } + + if err != nil { + return nil, errorer.newOpError("read", errors.New(err.String())) + } + + return read, nil +} + +// Read implements net.Conn.Read. +func (c *Conn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + deadline := c.readCancel() + + // Check if deadline has already expired. + select { + case <-deadline: + return 0, c.newOpError("read", &timeoutError{}) + default: + } + + if len(c.read) == 0 { + var err error + c.read, err = commonRead(c.ep, c.wq, deadline, nil, c) + if err != nil { + return 0, err + } + } + + n := copy(b, c.read) + c.read.TrimFront(n) + if len(c.read) == 0 { + c.read = nil + } + return n, nil +} + +// Write implements net.Conn.Write. +func (c *Conn) Write(b []byte) (int, error) { + deadline := c.writeCancel() + + // Check if deadlineTimer has already expired. + select { + case <-deadline: + return 0, c.newOpError("write", &timeoutError{}) + default: + } + + v := buffer.NewView(len(b)) + copy(v, b) + + // We must handle two soft failure conditions simultaneously: + // 1. Write may write nothing and return tcpip.ErrWouldBlock. + // If this happens, we need to register for notifications if we have + // not already and wait to try again. + // 2. Write may write fewer than the full number of bytes and return + // without error. In this case we need to try writing the remaining + // bytes again. I do not need to register for notifications. + // + // What is more, these two soft failure conditions can be interspersed. + // There is no guarantee that all of the condition #1s will occur before + // all of the condition #2s or visa-versa. + var ( + err *tcpip.Error + nbytes int + reg bool + notifyCh chan struct{} + ) + for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { + if err == tcpip.ErrWouldBlock { + if !reg { + // Only register once. + reg = true + + // Create wait queue entry that notifies a channel. + var waitEntry waiter.Entry + waitEntry, notifyCh = waiter.NewChannelEntry(nil) + c.wq.EventRegister(&waitEntry, waiter.EventOut) + defer c.wq.EventUnregister(&waitEntry) + } else { + // Don't wait immediately after registration in case more data + // became available between when we last checked and when we setup + // the notification. + select { + case <-deadline: + return nbytes, c.newOpError("write", &timeoutError{}) + case <-notifyCh: + } + } + } + + var n uintptr + n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + nbytes += int(n) + v.TrimFront(int(n)) + } + + if err == nil { + return nbytes, nil + } + + return nbytes, c.newOpError("write", errors.New(err.String())) +} + +// Close implements net.Conn.Close. +func (c *Conn) Close() error { + c.ep.Close() + return nil +} + +// LocalAddr implements net.Conn.LocalAddr. +func (c *Conn) LocalAddr() net.Addr { + a, err := c.ep.GetLocalAddress() + if err != nil { + return nil + } + return fullToTCPAddr(a) +} + +// RemoteAddr implements net.Conn.RemoteAddr. +func (c *Conn) RemoteAddr() net.Addr { + a, err := c.ep.GetRemoteAddress() + if err != nil { + return nil + } + return fullToTCPAddr(a) +} + +func (c *Conn) newOpError(op string, err error) *net.OpError { + return &net.OpError{ + Op: op, + Net: "tcp", + Source: c.LocalAddr(), + Addr: c.RemoteAddr(), + Err: err, + } +} + +func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { + return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} +} + +func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { + return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} +} + +// DialTCP creates a new TCP Conn connected to the specified address. +func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { + // Create TCP endpoint, then connect. + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) + if err != nil { + return nil, errors.New(err.String()) + } + + // Create wait queue entry that notifies a channel. + // + // We do this unconditionally as Connect will always return an error. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventOut) + defer wq.EventUnregister(&waitEntry) + + err = ep.Connect(addr) + if err == tcpip.ErrConnectStarted { + <-notifyCh + err = ep.GetSockOpt(tcpip.ErrorOption{}) + } + if err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "tcp", + Addr: fullToTCPAddr(addr), + Err: errors.New(err.String()), + } + } + + return NewConn(&wq, ep), nil +} + +// A PacketConn is a wrapper around a tcpip endpoint that implements +// net.PacketConn. +type PacketConn struct { + deadlineTimer + + stack *stack.Stack + ep tcpip.Endpoint + wq *waiter.Queue +} + +// NewPacketConn creates a new PacketConn. +func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { + // Create UDP endpoint and bind it. + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) + if err != nil { + return nil, errors.New(err.String()) + } + + if err := ep.Bind(addr, nil); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "udp", + Addr: fullToUDPAddr(addr), + Err: errors.New(err.String()), + } + } + + c := &PacketConn{ + stack: s, + ep: ep, + wq: &wq, + } + c.deadlineTimer.init() + return c, nil +} + +func (c *PacketConn) newOpError(op string, err error) *net.OpError { + return c.newRemoteOpError(op, nil, err) +} + +func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { + return &net.OpError{ + Op: op, + Net: "udp", + Source: c.LocalAddr(), + Addr: remote, + Err: err, + } +} + +// ReadFrom implements net.PacketConn.ReadFrom. +func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + deadline := c.readCancel() + + // Check if deadline has already expired. + select { + case <-deadline: + return 0, nil, c.newOpError("read", &timeoutError{}) + default: + } + + var addr tcpip.FullAddress + read, err := commonRead(c.ep, c.wq, deadline, &addr, c) + if err != nil { + return 0, nil, err + } + + return copy(b, read), fullToUDPAddr(addr), nil +} + +// WriteTo implements net.PacketConn.WriteTo. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + deadline := c.writeCancel() + + // Check if deadline has already expired. + select { + case <-deadline: + return 0, c.newRemoteOpError("write", addr, &timeoutError{}) + default: + } + + ua := addr.(*net.UDPAddr) + fullAddr := tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + + v := buffer.NewView(len(b)) + copy(v, b) + + wopts := tcpip.WriteOptions{To: &fullAddr} + n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) + + if err == tcpip.ErrWouldBlock { + // Create wait queue entry that notifies a channel. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.wq.EventRegister(&waitEntry, waiter.EventOut) + defer c.wq.EventUnregister(&waitEntry) + for { + n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + if err != tcpip.ErrWouldBlock { + break + } + select { + case <-deadline: + return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) + case <-notifyCh: + } + } + } + + if err == nil { + return int(n), nil + } + + return int(n), c.newRemoteOpError("write", addr, errors.New(err.String())) +} + +// Close implements net.PacketConn.Close. +func (c *PacketConn) Close() error { + c.ep.Close() + return nil +} + +// LocalAddr implements net.PacketConn.LocalAddr. +func (c *PacketConn) LocalAddr() net.Addr { + a, err := c.ep.GetLocalAddress() + if err != nil { + return nil + } + return fullToUDPAddr(a) +} diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go new file mode 100644 index 000000000..2f86469eb --- /dev/null +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -0,0 +1,424 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonet + +import ( + "fmt" + "net" + "reflect" + "strings" + "testing" + "time" + + "golang.org/x/net/nettest" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + NICID = 1 +) + +func TestTimeouts(t *testing.T) { + nc := NewConn(nil, nil) + dlfs := []struct { + name string + f func(time.Time) error + }{ + {"SetDeadline", nc.SetDeadline}, + {"SetReadDeadline", nc.SetReadDeadline}, + {"SetWriteDeadline", nc.SetWriteDeadline}, + } + + for _, dlf := range dlfs { + if err := dlf.f(time.Time{}); err != nil { + t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil) + } + } +} + +func newLoopbackStack() (*stack.Stack, *tcpip.Error) { + // Create the stack and add a NIC. + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}) + + if err := s.CreateNIC(NICID, loopback.New()); err != nil { + return nil, err + } + + // Add default route. + s.SetRouteTable([]tcpip.Route{ + // IPv4 + { + Destination: tcpip.Address(strings.Repeat("\x00", 4)), + Mask: tcpip.Address(strings.Repeat("\x00", 4)), + Gateway: "", + NIC: NICID, + }, + + // IPv6 + { + Destination: tcpip.Address(strings.Repeat("\x00", 16)), + Mask: tcpip.Address(strings.Repeat("\x00", 16)), + Gateway: "", + NIC: NICID, + }, + }) + + return s, nil +} + +type testConnection struct { + wq *waiter.Queue + e *waiter.Entry + ch chan struct{} + ep tcpip.Endpoint +} + +func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { + wq := &waiter.Queue{} + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + + entry, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&entry, waiter.EventOut) + + err = ep.Connect(addr) + if err == tcpip.ErrConnectStarted { + <-ch + err = ep.GetSockOpt(tcpip.ErrorOption{}) + } + if err != nil { + return nil, err + } + + wq.EventUnregister(&entry) + wq.EventRegister(&entry, waiter.EventIn) + + return &testConnection{wq, &entry, ch, ep}, nil +} + +func (c *testConnection) close() { + c.wq.EventUnregister(c.e) + c.ep.Close() +} + +// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock. +func TestCloseReader(t *testing.T) { + s, err := newLoopbackStack() + if err != nil { + t.Fatalf("newLoopbackStack() = %v", err) + } + + addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} + + s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + + l, e := NewListener(s, addr, ipv4.ProtocolNumber) + if e != nil { + t.Fatalf("NewListener() = %v", e) + } + done := make(chan struct{}) + go func() { + defer close(done) + c, err := l.Accept() + if err != nil { + t.Fatalf("l.Accept() = %v", err) + } + + // Give c.Read() a chance to block before closing the connection. + time.AfterFunc(time.Millisecond*50, func() { + c.Close() + }) + + buf := make([]byte, 256) + n, err := c.Read(buf) + got, ok := err.(*net.OpError) + want := tcpip.ErrConnectionAborted + if n != 0 || !ok || got.Err.Error() != want.String() { + t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + } + }() + sender, err := connect(s, addr) + if err != nil { + t.Fatalf("connect() = %v", err) + } + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Errorf("c.Read() didn't unblock") + } + sender.close() +} + +// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when +// using tcp.Forwarder. +func TestCloseReaderWithForwarder(t *testing.T) { + s, err := newLoopbackStack() + if err != nil { + t.Fatalf("newLoopbackStack() = %v", err) + } + + addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + + done := make(chan struct{}) + + fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { + defer close(done) + + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + t.Fatalf("r.CreateEndpoint() = %v", err) + } + defer ep.Close() + r.Complete(false) + + c := NewConn(&wq, ep) + + // Give c.Read() a chance to block before closing the connection. + time.AfterFunc(time.Millisecond*50, func() { + c.Close() + }) + + buf := make([]byte, 256) + n, e := c.Read(buf) + got, ok := e.(*net.OpError) + want := tcpip.ErrConnectionAborted + if n != 0 || !ok || got.Err.Error() != want.String() { + t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + } + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) + + sender, err := connect(s, addr) + if err != nil { + t.Fatalf("connect() = %v", err) + } + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Errorf("c.Read() didn't unblock") + } + sender.close() +} + +// TestDeadlineChange tests that changing the deadline affects currently blocked reads. +func TestDeadlineChange(t *testing.T) { + s, err := newLoopbackStack() + if err != nil { + t.Fatalf("newLoopbackStack() = %v", err) + } + + addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} + + s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + + l, e := NewListener(s, addr, ipv4.ProtocolNumber) + if e != nil { + t.Fatalf("NewListener() = %v", e) + } + done := make(chan struct{}) + go func() { + defer close(done) + c, err := l.Accept() + if err != nil { + t.Fatalf("l.Accept() = %v", err) + } + + c.SetDeadline(time.Now().Add(time.Minute)) + // Give c.Read() a chance to block before closing the connection. + time.AfterFunc(time.Millisecond*50, func() { + c.SetDeadline(time.Now().Add(time.Millisecond * 10)) + }) + + buf := make([]byte, 256) + n, err := c.Read(buf) + got, ok := err.(*net.OpError) + want := "i/o timeout" + if n != 0 || !ok || got.Err == nil || got.Err.Error() != want { + t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want) + } + }() + sender, err := connect(s, addr) + if err != nil { + t.Fatalf("connect() = %v", err) + } + + select { + case <-done: + case <-time.After(time.Millisecond * 500): + t.Errorf("c.Read() didn't unblock") + } + sender.close() +} + +func TestPacketConnTransfer(t *testing.T) { + s, e := newLoopbackStack() + if e != nil { + t.Fatalf("newLoopbackStack() = %v", e) + } + + ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr1 := tcpip.FullAddress{NICID, ip1, 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) + addr2 := tcpip.FullAddress{NICID, ip2, 11311} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + + c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("NewPacketConn(port 4):", err) + } + c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("NewPacketConn(port 5):", err) + } + + c1.SetDeadline(time.Now().Add(time.Second)) + c2.SetDeadline(time.Now().Add(time.Second)) + + sent := "abc123" + sendAddr := fullToUDPAddr(addr2) + if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { + t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) + } + recv := make([]byte, len(sent)) + n, recvAddr, err := c2.ReadFrom(recv) + if err != nil || n != len(recv) { + t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil) + } + + if recv := string(recv); recv != sent { + t.Errorf("got recv = %q, want = %q", recv, sent) + } + + if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) { + t.Errorf("got recvAddr = %v, want = %v", recvAddr, want) + } + + if err := c1.Close(); err != nil { + t.Error("c1.Close():", err) + } + if err := c2.Close(); err != nil { + t.Error("c2.Close():", err) + } +} + +func makePipe() (c1, c2 net.Conn, stop func(), err error) { + s, e := newLoopbackStack() + if e != nil { + return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e) + } + + ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr := tcpip.FullAddress{NICID, ip, 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + + l, err := NewListener(s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, nil, nil, fmt.Errorf("NewListener: %v", err) + } + + c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) + if err != nil { + l.Close() + return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) + } + + c2, err = l.Accept() + if err != nil { + l.Close() + c1.Close() + return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) + } + + stop = func() { + c1.Close() + c2.Close() + } + + if err := l.Close(); err != nil { + stop() + return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) + } + + return c1, c2, stop, nil +} + +func TestTCPConnTransfer(t *testing.T) { + c1, c2, _, err := makePipe() + if err != nil { + t.Fatal(err) + } + defer func() { + if err := c1.Close(); err != nil { + t.Error("c1.Close():", err) + } + if err := c2.Close(); err != nil { + t.Error("c2.Close():", err) + } + }() + + c1.SetDeadline(time.Now().Add(time.Second)) + c2.SetDeadline(time.Now().Add(time.Second)) + + const sent = "abc123" + + tests := []struct { + name string + c1 net.Conn + c2 net.Conn + }{ + {"connected to accepted", c1, c2}, + {"accepted to connected", c2, c1}, + } + + for _, test := range tests { + if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) { + t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil) + continue + } + + recv := make([]byte, len(sent)) + n, err := test.c2.Read(recv) + if err != nil || n != len(recv) { + t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil) + continue + } + + if recv := string(recv); recv != sent { + t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent) + } + } +} + +func TestTCPDialError(t *testing.T) { + s, e := newLoopbackStack() + if e != nil { + t.Fatalf("newLoopbackStack() = %v", e) + } + + ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr := tcpip.FullAddress{NICID, ip, 11211} + + _, err := DialTCP(s, addr, ipv4.ProtocolNumber) + got, ok := err.(*net.OpError) + want := tcpip.ErrNoRoute + if !ok || got.Err.Error() != want.String() { + t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) + } +} + +func TestNetTest(t *testing.T) { + nettest.TestConn(t, makePipe) +} |