summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/adapters/gonet
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/adapters/gonet')
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go111
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go83
2 files changed, 116 insertions, 78 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 711969b9b..6e0db2741 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
-// A Listener is a wrapper around a tcpip endpoint that implements
+// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements
// net.Listener.
-type Listener struct {
+type TCPListener struct {
stack *stack.Stack
ep tcpip.Endpoint
wq *waiter.Queue
cancel chan struct{}
}
-// NewListener creates a new Listener.
-func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
- // Create TCP endpoint, bind it, then start listening.
+// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint.
+func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
+ return &TCPListener{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ cancel: make(chan struct{}),
+ }
+}
+
+// ListenTCP creates a new TCPListener.
+func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
+ // Create a TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
@@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr
}
}
- return &Listener{
- stack: s,
- ep: ep,
- wq: &wq,
- cancel: make(chan struct{}),
- }, nil
+ return NewTCPListener(s, &wq, ep), nil
}
// Close implements net.Listener.Close.
-func (l *Listener) Close() error {
+func (l *TCPListener) Close() error {
l.ep.Close()
return nil
}
// Shutdown stops the HTTP server.
-func (l *Listener) Shutdown() {
+func (l *TCPListener) Shutdown() {
l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
close(l.cancel) // broadcast cancellation
}
// Addr implements net.Listener.Addr.
-func (l *Listener) Addr() net.Addr {
+func (l *TCPListener) Addr() net.Addr {
a, err := l.ep.GetLocalAddress()
if err != nil {
return nil
@@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error {
return nil
}
-// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
+// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn
// interface.
-type Conn struct {
+type TCPConn struct {
deadlineTimer
wq *waiter.Queue
@@ -228,9 +233,9 @@ type Conn struct {
read buffer.View
}
-// NewConn creates a new Conn.
-func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
- c := &Conn{
+// NewTCPConn creates a new TCPConn.
+func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
+ c := &TCPConn{
wq: wq,
ep: ep,
}
@@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
}
// Accept implements net.Conn.Accept.
-func (l *Listener) Accept() (net.Conn, error) {
+func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept()
if err == tcpip.ErrWouldBlock {
@@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) {
}
}
- return NewConn(wq, n), nil
+ return NewTCPConn(wq, n), nil
}
type opErrorer interface {
@@ -323,7 +328,7 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a
}
// Read implements net.Conn.Read.
-func (c *Conn) Read(b []byte) (int, error) {
+func (c *TCPConn) Read(b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
@@ -352,7 +357,7 @@ func (c *Conn) Read(b []byte) (int, error) {
}
// Write implements net.Conn.Write.
-func (c *Conn) Write(b []byte) (int, error) {
+func (c *TCPConn) Write(b []byte) (int, error) {
deadline := c.writeCancel()
// Check if deadlineTimer has already expired.
@@ -431,7 +436,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
// Close implements net.Conn.Close.
-func (c *Conn) Close() error {
+func (c *TCPConn) Close() error {
c.ep.Close()
return nil
}
@@ -440,7 +445,7 @@ func (c *Conn) Close() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
-func (c *Conn) CloseRead() error {
+func (c *TCPConn) CloseRead() error {
if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -451,7 +456,7 @@ func (c *Conn) CloseRead() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
-func (c *Conn) CloseWrite() error {
+func (c *TCPConn) CloseWrite() error {
if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -459,7 +464,7 @@ func (c *Conn) CloseWrite() error {
}
// LocalAddr implements net.Conn.LocalAddr.
-func (c *Conn) LocalAddr() net.Addr {
+func (c *TCPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
@@ -468,7 +473,7 @@ func (c *Conn) LocalAddr() net.Addr {
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *Conn) RemoteAddr() net.Addr {
+func (c *TCPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -476,7 +481,7 @@ func (c *Conn) RemoteAddr() net.Addr {
return fullToTCPAddr(a)
}
-func (c *Conn) newOpError(op string, err error) *net.OpError {
+func (c *TCPConn) newOpError(op string, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "tcp",
@@ -494,14 +499,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
}
-// DialTCP creates a new TCP Conn connected to the specified address.
-func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+// DialTCP creates a new TCPConn connected to the specified address.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCP Conn connected to the specified address
+// DialContextTCP creates a new TCPConn connected to the specified address
// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -543,12 +548,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
}
- return NewConn(&wq, ep), nil
+ return NewTCPConn(&wq, ep), nil
}
-// A PacketConn is a wrapper around a tcpip endpoint that implements
-// net.PacketConn.
-type PacketConn struct {
+// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
+// net.Conn and net.PacketConn.
+type UDPConn struct {
deadlineTimer
stack *stack.Stack
@@ -556,9 +561,9 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn {
- c := &PacketConn{
+// NewUDPConn creates a new UDPConn.
+func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
+ c := &UDPConn{
stack: s,
ep: ep,
wq: wq,
@@ -567,12 +572,12 @@ func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketC
return c
}
-// DialUDP creates a new PacketConn.
+// DialUDP creates a new UDPConn.
//
// If laddr is nil, a local address is automatically chosen.
//
-// If raddr is nil, the PacketConn is left unconnected.
-func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
+// If raddr is nil, the UDPConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
@@ -591,7 +596,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := NewPacketConn(s, &wq, ep)
+ c := NewUDPConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -608,11 +613,11 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
return c, nil
}
-func (c *PacketConn) newOpError(op string, err error) *net.OpError {
+func (c *UDPConn) newOpError(op string, err error) *net.OpError {
return c.newRemoteOpError(op, nil, err)
}
-func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "udp",
@@ -623,7 +628,7 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *PacketConn) RemoteAddr() net.Addr {
+func (c *UDPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -632,13 +637,13 @@ func (c *PacketConn) RemoteAddr() net.Addr {
}
// Read implements net.Conn.Read
-func (c *PacketConn) Read(b []byte) (int, error) {
+func (c *UDPConn) Read(b []byte) (int, error) {
bytesRead, _, err := c.ReadFrom(b)
return bytesRead, err
}
// ReadFrom implements net.PacketConn.ReadFrom.
-func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
@@ -650,12 +655,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return copy(b, read), fullToUDPAddr(addr), nil
}
-func (c *PacketConn) Write(b []byte) (int, error) {
+func (c *UDPConn) Write(b []byte) (int, error) {
return c.WriteTo(b, nil)
}
// WriteTo implements net.PacketConn.WriteTo.
-func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
deadline := c.writeCancel()
// Check if deadline has already expired.
@@ -713,13 +718,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// Close implements net.PacketConn.Close.
-func (c *PacketConn) Close() error {
+func (c *UDPConn) Close() error {
c.ep.Close()
return nil
}
// LocalAddr implements net.PacketConn.LocalAddr.
-func (c *PacketConn) LocalAddr() net.Addr {
+func (c *UDPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ee077ae83..3c552988a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -41,7 +41,7 @@ const (
)
func TestTimeouts(t *testing.T) {
- nc := NewConn(nil, nil)
+ nc := NewTCPConn(nil, nil)
dlfs := []struct {
name string
f func(time.Time) error
@@ -127,12 +127,16 @@ func TestCloseReader(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -168,13 +172,17 @@ func TestCloseReader(t *testing.T) {
sender.close()
}
-// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
+// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
// using tcp.Forwarder.
func TestCloseReaderWithForwarder(t *testing.T) {
s, err := newLoopbackStack()
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -192,7 +200,7 @@ func TestCloseReaderWithForwarder(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
@@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
+ _, err := r.CreateEndpoint(&wq)
if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err)
}
- defer ep.Close()
- r.Complete(false)
-
- c := NewConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
+ // Endpoint will be closed in deferred s.Close (above).
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -257,7 +256,7 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseRead(); err != nil {
t.Errorf("c.CloseRead() = %v", err)
@@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -291,7 +294,7 @@ func TestCloseWrite(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
n, e := c.Read(make([]byte, 256))
if n != 0 || e != io.EOF {
@@ -309,7 +312,7 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseWrite(); err != nil {
t.Errorf("c.CloseWrite() = %v", err)
@@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -353,7 +360,7 @@ func TestUDPForwarder(t *testing.T) {
}
defer ep.Close()
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
@@ -391,12 +398,16 @@ func TestDeadlineChange(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -541,7 +560,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
addr := tcpip.FullAddress{NICID, ip, 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
}
@@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() {
c1.Close()
c2.Close()
+ s.Close()
+ s.Wait()
}
if err := l.Close(); err != nil {
@@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)