diff options
-rw-r--r-- | server/sockopt_linux.go | 18 | ||||
-rw-r--r-- | server/sockopt_linux_test.go | 40 |
2 files changed, 52 insertions, 6 deletions
diff --git a/server/sockopt_linux.go b/server/sockopt_linux.go index 146f87b6..336caf15 100644 --- a/server/sockopt_linux.go +++ b/server/sockopt_linux.go @@ -177,6 +177,15 @@ func (d *TCPDialer) DialTCP(addr string, port int) (*net.TCPConn, error) { if err != nil { return nil, err } + fi := os.NewFile(uintptr(fd), "") + defer fi.Close() + // A new socket was created so we must close it before this + // function returns either on failure or success. On success, + // net.FileConn() in newTCPConn() increases the refcount of + // the socket so this fi.Close() doesn't destroy the socket. + // The caller must call Close() with the file later. + // Note that the above os.NewFile() doesn't play with the + // refcount. if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1); err != nil { return nil, os.NewSyscallError("setsockopt", err) @@ -208,10 +217,7 @@ func (d *TCPDialer) DialTCP(addr string, port int) (*net.TCPConn, error) { return nil, os.NewSyscallError("bind", err) } - newTCPConn := func(fd int) (*net.TCPConn, error) { - fi := os.NewFile(uintptr(fd), "") - defer fi.Close() - + newTCPConn := func(fi *os.File) (*net.TCPConn, error) { if conn, err := net.FileConn(fi); err != nil { return nil, err } else { @@ -224,7 +230,7 @@ func (d *TCPDialer) DialTCP(addr string, port int) (*net.TCPConn, error) { case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: // do timeout handling case nil, syscall.EISCONN: - return newTCPConn(fd) + return newTCPConn(fi) default: return nil, os.NewSyscallError("connect", err) } @@ -259,7 +265,7 @@ func (d *TCPDialer) DialTCP(addr string, port int) (*net.TCPConn, error) { switch err := syscall.Errno(nerr); err { case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: case syscall.Errno(0), syscall.EISCONN: - return newTCPConn(fd) + return newTCPConn(fi) default: return nil, os.NewSyscallError("getsockopt", err) } diff --git a/server/sockopt_linux_test.go b/server/sockopt_linux_test.go index f78431af..3730672c 100644 --- a/server/sockopt_linux_test.go +++ b/server/sockopt_linux_test.go @@ -18,8 +18,12 @@ package server import ( "bytes" + "fmt" + "net" + "os" "syscall" "testing" + "time" "unsafe" ) @@ -64,3 +68,39 @@ func Test_buildTcpMD5Sigv6(t *testing.T) { t.Error("Something wrong v6") } } + +func Test_DialTCP_FDleak(t *testing.T) { + openFds := func() int { + pid := os.Getpid() + f, err := os.OpenFile(fmt.Sprintf("/proc/%d/fdinfo", pid), os.O_RDONLY, 0) + if err != nil { + t.Fatal(err) + } + defer f.Close() + names, err := f.Readdirnames(0) + if err != nil { + t.Fatal(err) + } + return len(names) + } + + before := openFds() + + for i := 0; i < 10; i++ { + laddr, _ := net.ResolveTCPAddr("tcp", net.JoinHostPort("127.0.0.1", "0")) + d := TCPDialer{ + Dialer: net.Dialer{ + LocalAddr: laddr, + Timeout: 1 * time.Second, + }, + } + if _, err := d.DialTCP("127.0.0.1", 1); err == nil { + t.Fatalf("should not succeed") + } + + } + + if after := openFds(); before != after { + t.Fatalf("could be fd leak, %d %d", before, after) + } +} |