diff options
Diffstat (limited to 'pkg/tcpip/adapters')
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet_test.go | 110 |
2 files changed, 132 insertions, 0 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 232d44d24..628e28f57 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -435,6 +435,28 @@ func (c *Conn) Close() error { return nil } +// CloseRead shuts down the reading side of the TCP connection. Most callers +// should just use Close. +// +// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. +func (c *Conn) CloseRead() error { + if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { + return c.newOpError("close", errors.New(terr.String())) + } + return nil +} + +// CloseWrite shuts down the writing side of the TCP connection. Most callers +// should just use Close. +// +// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. +func (c *Conn) CloseWrite() error { + if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { + return c.newOpError("close", errors.New(terr.String())) + } + return nil +} + // LocalAddr implements net.Conn.LocalAddr. func (c *Conn) LocalAddr() net.Addr { a, err := c.ep.GetLocalAddress() diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ab3da2e4e..e84f73feb 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -16,6 +16,7 @@ package gonet import ( "fmt" + "io" "net" "reflect" "strings" @@ -222,6 +223,115 @@ func TestCloseReaderWithForwarder(t *testing.T) { sender.close() } +func TestCloseRead(t *testing.T) { + s, terr := newLoopbackStack() + if terr != nil { + t.Fatalf("newLoopbackStack() = %v", terr) + } + + addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + + fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + t.Fatalf("r.CreateEndpoint() = %v", err) + } + defer ep.Close() + r.Complete(false) + + c := NewConn(&wq, ep) + + buf := make([]byte, 256) + n, e := c.Read(buf) + if e != nil || string(buf[:n]) != "abc123" { + t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) + } + + if n, e = c.Write([]byte("abc123")); e != nil { + t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) + } + }) + + s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) + + tc, terr := connect(s, addr) + if terr != nil { + t.Fatalf("connect() = %v", terr) + } + c := NewConn(tc.wq, tc.ep) + + if err := c.CloseRead(); err != nil { + t.Errorf("c.CloseRead() = %v", err) + } + + buf := make([]byte, 256) + if n, err := c.Read(buf); err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err) + } + + if n, err := c.Write([]byte("abc123")); n != 6 || err != nil { + t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err) + } +} + +func TestCloseWrite(t *testing.T) { + s, terr := newLoopbackStack() + if terr != nil { + t.Fatalf("newLoopbackStack() = %v", terr) + } + + addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + + fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + t.Fatalf("r.CreateEndpoint() = %v", err) + } + defer ep.Close() + r.Complete(false) + + c := NewConn(&wq, ep) + + n, e := c.Read(make([]byte, 256)) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e) + } + + if n, e = c.Write([]byte("abc123")); n != 6 || e != nil { + t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) + } + }) + + s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) + + tc, terr := connect(s, addr) + if terr != nil { + t.Fatalf("connect() = %v", terr) + } + c := NewConn(tc.wq, tc.ep) + + if err := c.CloseWrite(); err != nil { + t.Errorf("c.CloseWrite() = %v", err) + } + + buf := make([]byte, 256) + n, err := c.Read(buf) + if err != nil || string(buf[:n]) != "abc123" { + t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err) + } + + n, err = c.Write([]byte("abc123")) + got, ok := err.(*net.OpError) + want := "endpoint is closed for send" + if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) { + t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want) + } +} + func TestUDPForwarder(t *testing.T) { s, terr := newLoopbackStack() if terr != nil { |