summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorBen Burkert <ben@benburkert.com>2019-04-26 22:45:45 -0700
committerShentubot <shentubot@google.com>2019-04-26 22:46:45 -0700
commit66bca6fc221393c9553cbaa0486e07c8124e2477 (patch)
treef16003e77764b22369af09479db0a51234b9f68b /pkg
parent43dff57b878edb5502daf486cbc13b058780dd56 (diff)
tcpip/adapters/gonet: add CloseRead & CloseWrite methods to Conn
Add the CloseRead & CloseWrite methods that performs shutdown on the corresponding Read & Write sides of a connection. Change-Id: I3996a2abdc7cd68a2becba44dc4bd9f0919d2ce1 PiperOrigin-RevId: 245537950
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go22
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go110
2 files changed, 132 insertions, 0 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 232d44d24..628e28f57 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -435,6 +435,28 @@ func (c *Conn) Close() error {
return nil
}
+// CloseRead shuts down the reading side of the TCP connection. Most callers
+// should just use Close.
+//
+// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
+func (c *Conn) CloseRead() error {
+ if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
+ return c.newOpError("close", errors.New(terr.String()))
+ }
+ return nil
+}
+
+// CloseWrite shuts down the writing side of the TCP connection. Most callers
+// should just use Close.
+//
+// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
+func (c *Conn) CloseWrite() error {
+ if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
+ return c.newOpError("close", errors.New(terr.String()))
+ }
+ return nil
+}
+
// LocalAddr implements net.Conn.LocalAddr.
func (c *Conn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ab3da2e4e..e84f73feb 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -16,6 +16,7 @@ package gonet
import (
"fmt"
+ "io"
"net"
"reflect"
"strings"
@@ -222,6 +223,115 @@ func TestCloseReaderWithForwarder(t *testing.T) {
sender.close()
}
+func TestCloseRead(t *testing.T) {
+ s, terr := newLoopbackStack()
+ if terr != nil {
+ t.Fatalf("newLoopbackStack() = %v", terr)
+ }
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ var wq waiter.Queue
+ ep, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ defer ep.Close()
+ r.Complete(false)
+
+ c := NewConn(&wq, ep)
+
+ buf := make([]byte, 256)
+ n, e := c.Read(buf)
+ if e != nil || string(buf[:n]) != "abc123" {
+ t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
+ }
+
+ if n, e = c.Write([]byte("abc123")); e != nil {
+ t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
+ }
+ })
+
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ tc, terr := connect(s, addr)
+ if terr != nil {
+ t.Fatalf("connect() = %v", terr)
+ }
+ c := NewConn(tc.wq, tc.ep)
+
+ if err := c.CloseRead(); err != nil {
+ t.Errorf("c.CloseRead() = %v", err)
+ }
+
+ buf := make([]byte, 256)
+ if n, err := c.Read(buf); err != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err)
+ }
+
+ if n, err := c.Write([]byte("abc123")); n != 6 || err != nil {
+ t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err)
+ }
+}
+
+func TestCloseWrite(t *testing.T) {
+ s, terr := newLoopbackStack()
+ if terr != nil {
+ t.Fatalf("newLoopbackStack() = %v", terr)
+ }
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ var wq waiter.Queue
+ ep, err := r.CreateEndpoint(&wq)
+ if err != nil {
+ t.Fatalf("r.CreateEndpoint() = %v", err)
+ }
+ defer ep.Close()
+ r.Complete(false)
+
+ c := NewConn(&wq, ep)
+
+ n, e := c.Read(make([]byte, 256))
+ if n != 0 || e != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e)
+ }
+
+ if n, e = c.Write([]byte("abc123")); n != 6 || e != nil {
+ t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
+ }
+ })
+
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ tc, terr := connect(s, addr)
+ if terr != nil {
+ t.Fatalf("connect() = %v", terr)
+ }
+ c := NewConn(tc.wq, tc.ep)
+
+ if err := c.CloseWrite(); err != nil {
+ t.Errorf("c.CloseWrite() = %v", err)
+ }
+
+ buf := make([]byte, 256)
+ n, err := c.Read(buf)
+ if err != nil || string(buf[:n]) != "abc123" {
+ t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err)
+ }
+
+ n, err = c.Write([]byte("abc123"))
+ got, ok := err.(*net.OpError)
+ want := "endpoint is closed for send"
+ if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) {
+ t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want)
+ }
+}
+
func TestUDPForwarder(t *testing.T) {
s, terr := newLoopbackStack()
if terr != nil {