diff options
Diffstat (limited to 'pkg/tcpip/adapters/gonet')
-rw-r--r-- | pkg/tcpip/adapters/gonet/BUILD | 8 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 127 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet_test.go | 95 |
3 files changed, 135 insertions, 95 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 78df5a0b1..e57d45f2a 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,14 +1,13 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "gonet", srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", @@ -22,7 +21,8 @@ go_test( name = "gonet_test", size = "small", srcs = ["gonet_test.go"], - embed = [":gonet"], + library = ":gonet", + tags = ["flaky"], deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index cd6ce930a..6e0db2741 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -20,9 +20,9 @@ import ( "errors" "io" "net" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" } func (e *timeoutError) Timeout() bool { return true } func (e *timeoutError) Temporary() bool { return true } -// A Listener is a wrapper around a tcpip endpoint that implements +// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements // net.Listener. -type Listener struct { +type TCPListener struct { stack *stack.Stack ep tcpip.Endpoint wq *waiter.Queue cancel chan struct{} } -// NewListener creates a new Listener. -func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { - // Create TCP endpoint, bind it, then start listening. +// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint. +func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener { + return &TCPListener{ + stack: s, + ep: ep, + wq: wq, + cancel: make(chan struct{}), + } +} + +// ListenTCP creates a new TCPListener. +func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) { + // Create a TCP endpoint, bind it, then start listening. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) if err != nil { @@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr } } - return &Listener{ - stack: s, - ep: ep, - wq: &wq, - cancel: make(chan struct{}), - }, nil + return NewTCPListener(s, &wq, ep), nil } // Close implements net.Listener.Close. -func (l *Listener) Close() error { +func (l *TCPListener) Close() error { l.ep.Close() return nil } // Shutdown stops the HTTP server. -func (l *Listener) Shutdown() { +func (l *TCPListener) Shutdown() { l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) close(l.cancel) // broadcast cancellation } // Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { +func (l *TCPListener) Addr() net.Addr { a, err := l.ep.GetLocalAddress() if err != nil { return nil @@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error { return nil } -// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn +// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn // interface. -type Conn struct { +type TCPConn struct { deadlineTimer wq *waiter.Queue @@ -228,9 +233,9 @@ type Conn struct { read buffer.View } -// NewConn creates a new Conn. -func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { - c := &Conn{ +// NewTCPConn creates a new TCPConn. +func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { + c := &TCPConn{ wq: wq, ep: ep, } @@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { } // Accept implements net.Conn.Accept. -func (l *Listener) Accept() (net.Conn, error) { +func (l *TCPListener) Accept() (net.Conn, error) { n, wq, err := l.ep.Accept() if err == tcpip.ErrWouldBlock { @@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) { } } - return NewConn(wq, n), nil + return NewTCPConn(wq, n), nil } type opErrorer interface { @@ -323,7 +328,7 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a } // Read implements net.Conn.Read. -func (c *Conn) Read(b []byte) (int, error) { +func (c *TCPConn) Read(b []byte) (int, error) { c.readMu.Lock() defer c.readMu.Unlock() @@ -352,7 +357,7 @@ func (c *Conn) Read(b []byte) (int, error) { } // Write implements net.Conn.Write. -func (c *Conn) Write(b []byte) (int, error) { +func (c *TCPConn) Write(b []byte) (int, error) { deadline := c.writeCancel() // Check if deadlineTimer has already expired. @@ -431,7 +436,7 @@ func (c *Conn) Write(b []byte) (int, error) { } // Close implements net.Conn.Close. -func (c *Conn) Close() error { +func (c *TCPConn) Close() error { c.ep.Close() return nil } @@ -440,7 +445,7 @@ func (c *Conn) Close() error { // should just use Close. // // A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. -func (c *Conn) CloseRead() error { +func (c *TCPConn) CloseRead() error { if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { return c.newOpError("close", errors.New(terr.String())) } @@ -451,7 +456,7 @@ func (c *Conn) CloseRead() error { // should just use Close. // // A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. -func (c *Conn) CloseWrite() error { +func (c *TCPConn) CloseWrite() error { if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { return c.newOpError("close", errors.New(terr.String())) } @@ -459,7 +464,7 @@ func (c *Conn) CloseWrite() error { } // LocalAddr implements net.Conn.LocalAddr. -func (c *Conn) LocalAddr() net.Addr { +func (c *TCPConn) LocalAddr() net.Addr { a, err := c.ep.GetLocalAddress() if err != nil { return nil @@ -468,7 +473,7 @@ func (c *Conn) LocalAddr() net.Addr { } // RemoteAddr implements net.Conn.RemoteAddr. -func (c *Conn) RemoteAddr() net.Addr { +func (c *TCPConn) RemoteAddr() net.Addr { a, err := c.ep.GetRemoteAddress() if err != nil { return nil @@ -476,7 +481,7 @@ func (c *Conn) RemoteAddr() net.Addr { return fullToTCPAddr(a) } -func (c *Conn) newOpError(op string, err error) *net.OpError { +func (c *TCPConn) newOpError(op string, err error) *net.OpError { return &net.OpError{ Op: op, Net: "tcp", @@ -494,14 +499,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} } -// DialTCP creates a new TCP Conn connected to the specified address. -func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { +// DialTCP creates a new TCPConn connected to the specified address. +func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { return DialContextTCP(context.Background(), s, addr, network) } -// DialContextTCP creates a new TCP Conn connected to the specified address +// DialContextTCP creates a new TCPConn connected to the specified address // with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -543,12 +548,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, } } - return NewConn(&wq, ep), nil + return NewTCPConn(&wq, ep), nil } -// A PacketConn is a wrapper around a tcpip endpoint that implements -// net.PacketConn. -type PacketConn struct { +// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements +// net.Conn and net.PacketConn. +type UDPConn struct { deadlineTimer stack *stack.Stack @@ -556,12 +561,23 @@ type PacketConn struct { wq *waiter.Queue } -// DialUDP creates a new PacketConn. +// NewUDPConn creates a new UDPConn. +func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn { + c := &UDPConn{ + stack: s, + ep: ep, + wq: wq, + } + c.deadlineTimer.init() + return c +} + +// DialUDP creates a new UDPConn. // // If laddr is nil, a local address is automatically chosen. // -// If raddr is nil, the PacketConn is left unconnected. -func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { +// If raddr is nil, the UDPConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) { var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) if err != nil { @@ -580,12 +596,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw } } - c := PacketConn{ - stack: s, - ep: ep, - wq: &wq, - } - c.deadlineTimer.init() + c := NewUDPConn(s, &wq, ep) if raddr != nil { if err := c.ep.Connect(*raddr); err != nil { @@ -599,14 +610,14 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw } } - return &c, nil + return c, nil } -func (c *PacketConn) newOpError(op string, err error) *net.OpError { +func (c *UDPConn) newOpError(op string, err error) *net.OpError { return c.newRemoteOpError(op, nil, err) } -func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { +func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { return &net.OpError{ Op: op, Net: "udp", @@ -617,22 +628,22 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne } // RemoteAddr implements net.Conn.RemoteAddr. -func (c *PacketConn) RemoteAddr() net.Addr { +func (c *UDPConn) RemoteAddr() net.Addr { a, err := c.ep.GetRemoteAddress() if err != nil { return nil } - return fullToTCPAddr(a) + return fullToUDPAddr(a) } // Read implements net.Conn.Read -func (c *PacketConn) Read(b []byte) (int, error) { +func (c *UDPConn) Read(b []byte) (int, error) { bytesRead, _, err := c.ReadFrom(b) return bytesRead, err } // ReadFrom implements net.PacketConn.ReadFrom. -func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { deadline := c.readCancel() var addr tcpip.FullAddress @@ -644,12 +655,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { return copy(b, read), fullToUDPAddr(addr), nil } -func (c *PacketConn) Write(b []byte) (int, error) { +func (c *UDPConn) Write(b []byte) (int, error) { return c.WriteTo(b, nil) } // WriteTo implements net.PacketConn.WriteTo. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { deadline := c.writeCancel() // Check if deadline has already expired. @@ -707,13 +718,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // Close implements net.PacketConn.Close. -func (c *PacketConn) Close() error { +func (c *UDPConn) Close() error { c.ep.Close() return nil } // LocalAddr implements net.PacketConn.LocalAddr. -func (c *PacketConn) LocalAddr() net.Addr { +func (c *UDPConn) LocalAddr() net.Addr { a, err := c.ep.GetLocalAddress() if err != nil { return nil diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -41,7 +41,7 @@ const ( ) func TestTimeouts(t *testing.T) { - nc := NewConn(nil, nil) + nc := NewTCPConn(nil, nil) dlfs := []struct { name string f func(time.Time) error @@ -127,12 +127,16 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - l, e := NewListener(s, addr, ipv4.ProtocolNumber) + l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { t.Fatalf("NewListener() = %v", e) } @@ -151,10 +155,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -170,13 +172,17 @@ func TestCloseReader(t *testing.T) { sender.close() } -// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when +// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when // using tcp.Forwarder. func TestCloseReaderWithForwarder(t *testing.T) { s, err := newLoopbackStack() if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -194,7 +200,7 @@ func TestCloseReaderWithForwarder(t *testing.T) { defer ep.Close() r.Complete(false) - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) // Give c.Read() a chance to block before closing the connection. time.AfterFunc(time.Millisecond*50, func() { @@ -203,10 +209,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -229,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -261,7 +256,7 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("connect() = %v", terr) } - c := NewConn(tc.wq, tc.ep) + c := NewTCPConn(tc.wq, tc.ep) if err := c.CloseRead(); err != nil { t.Errorf("c.CloseRead() = %v", err) @@ -282,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -295,7 +294,7 @@ func TestCloseWrite(t *testing.T) { defer ep.Close() r.Complete(false) - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) n, e := c.Read(make([]byte, 256)) if n != 0 || e != io.EOF { @@ -313,7 +312,7 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("connect() = %v", terr) } - c := NewConn(tc.wq, tc.ep) + c := NewTCPConn(tc.wq, tc.ep) if err := c.CloseWrite(); err != nil { t.Errorf("c.CloseWrite() = %v", err) @@ -338,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -357,7 +360,7 @@ func TestUDPForwarder(t *testing.T) { } defer ep.Close() - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) buf := make([]byte, 256) n, e := c.Read(buf) @@ -395,12 +398,16 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - l, e := NewListener(s, addr, ipv4.ProtocolNumber) + l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { t.Fatalf("NewListener() = %v", e) } @@ -444,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -496,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -545,7 +560,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { addr := tcpip.FullAddress{NICID, ip, 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - l, err := NewListener(s, addr, ipv4.ProtocolNumber) + l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) if err != nil { return nil, nil, nil, fmt.Errorf("NewListener: %v", err) } @@ -566,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -628,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -645,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -663,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) |