diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-10-30 02:39:56 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-11-04 12:53:52 +0100 |
commit | c07dd60cdb8eb3fc87b63ed0938979e4e4fb6278 (patch) | |
tree | 091349ddf5b90d2fa802c752436c158a25577e57 /ipc | |
parent | eb6302c7eb71e3e3df9f63395bc5c97dcf0efc84 (diff) |
namedpipe: rename from winpipe to keep in sync with CL299009
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'ipc')
-rw-r--r-- | ipc/namedpipe/file.go (renamed from ipc/winpipe/file.go) | 14 | ||||
-rw-r--r-- | ipc/namedpipe/namedpipe.go (renamed from ipc/winpipe/winpipe.go) | 92 | ||||
-rw-r--r-- | ipc/namedpipe/namedpipe_test.go (renamed from ipc/winpipe/winpipe_test.go) | 125 | ||||
-rw-r--r-- | ipc/uapi_windows.go | 8 |
4 files changed, 132 insertions, 107 deletions
diff --git a/ipc/winpipe/file.go b/ipc/namedpipe/file.go index 319565f..9c2481d 100644 --- a/ipc/winpipe/file.go +++ b/ipc/namedpipe/file.go @@ -1,12 +1,12 @@ -//go:build windows +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ +//go:build windows +// +build windows -package winpipe +package namedpipe import ( "io" diff --git a/ipc/winpipe/winpipe.go b/ipc/namedpipe/namedpipe.go index e3719d6..6db5ea3 100644 --- a/ipc/winpipe/winpipe.go +++ b/ipc/namedpipe/namedpipe.go @@ -1,13 +1,13 @@ -//go:build windows +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ +//go:build windows +// +build windows -// Package winpipe implements a net.Conn and net.Listener around Windows named pipes. -package winpipe +// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes. +package namedpipe import ( "context" @@ -15,6 +15,7 @@ import ( "net" "os" "runtime" + "sync/atomic" "time" "unsafe" @@ -28,7 +29,7 @@ type pipe struct { type messageBytePipe struct { pipe - writeClosed bool + writeClosed int32 readEOF bool } @@ -50,25 +51,26 @@ func (f *pipe) SetDeadline(t time.Time) error { // CloseWrite closes the write side of a message pipe in byte mode. func (f *messageBytePipe) CloseWrite() error { - if f.writeClosed { + if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) { return io.ErrClosedPipe } err := f.file.Flush() if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } _, err = f.file.Write(nil) if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } - f.writeClosed = true return nil } // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // they are used to implement CloseWrite. func (f *messageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { + if atomic.LoadInt32(&f.writeClosed) != 0 { return 0, io.ErrClosedPipe } if len(b) == 0 { @@ -142,30 +144,24 @@ type DialConfig struct { ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. } -// Dial connects to the specified named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. -func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(2 * time.Second) +// DialTimeout connects to the specified named pipe by path, timing out if the +// connection takes longer than the specified duration. If timeout is zero, then +// we use a default timeout of 2 seconds. +func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + if timeout == 0 { + timeout = time.Second * 2 } + absTimeout := time.Now().Add(timeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialContext(ctx, path, config) + conn, err := config.DialContext(ctx, path) if err == context.DeadlineExceeded { return nil, os.ErrDeadlineExceeded } return conn, err } -// DialContext attempts to connect to the specified named pipe by path -// cancellation or timeout. -func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { - if config == nil { - config = &DialConfig{} - } +// DialContext attempts to connect to the specified named pipe by path. +func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) { var err error var h windows.Handle h, err = tryDialPipe(ctx, &path) @@ -213,6 +209,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn return &pipe{file: f, path: path}, nil } +var defaultDialer DialConfig + +// DialTimeout calls DialConfig.DialTimeout using an empty configuration. +func DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + return defaultDialer.DialTimeout(path, timeout) +} + +// DialContext calls DialConfig.DialContext using an empty configuration. +func DialContext(ctx context.Context, path string) (net.Conn, error) { + return defaultDialer.DialContext(ctx, path) +} + type acceptResponse struct { f *file err error @@ -222,12 +230,12 @@ type pipeListener struct { firstHandle windows.Handle path string config ListenConfig - acceptCh chan (chan acceptResponse) + acceptCh chan chan acceptResponse closeCh chan int doneCh chan int } -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) { path16, err := windows.UTF16PtrFromString(path) if err != nil { return 0, &os.PathError{Op: "open", Path: path, Err: err} @@ -247,7 +255,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste oa.ObjectName = &ntPath // The security descriptor is only needed for the first pipe. - if first { + if isFirstPipe { if sd != nil { oa.SecurityDescriptor = sd } else { @@ -257,7 +265,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste return 0, err } defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) - sd, err := windows.NewSecurityDescriptor() + sd, err = windows.NewSecurityDescriptor() if err != nil { return 0, err } @@ -275,11 +283,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste disposition := uint32(windows.FILE_OPEN) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { + if isFirstPipe { disposition = windows.FILE_CREATE // By not asking for read or write access, the named pipe file system // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. + // client connections until the next call with isFirstPipe == false. access = windows.SYNCHRONIZE } @@ -395,10 +403,7 @@ type ListenConfig struct { // Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. // The pipe must not already exist. -func Listen(path string, c *ListenConfig) (net.Listener, error) { - if c == nil { - c = &ListenConfig{} - } +func (c *ListenConfig) Listen(path string) (net.Listener, error) { h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) if err != nil { return nil, err @@ -407,12 +412,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) { firstHandle: h, path: path, config: *c, - acceptCh: make(chan (chan acceptResponse)), + acceptCh: make(chan chan acceptResponse), closeCh: make(chan int), doneCh: make(chan int), } // The first connection is swallowed on Windows 7 & 8, so synthesize it. - if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) { path16, err := windows.UTF16PtrFromString(path) if err == nil { h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) @@ -425,6 +430,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) { return l, nil } +var defaultListener ListenConfig + +// Listen calls ListenConfig.Listen using an empty configuration. +func Listen(path string) (net.Listener, error) { + return defaultListener.Listen(path) +} + func connectPipe(p *file) error { c, err := p.prepareIo() if err != nil { diff --git a/ipc/winpipe/winpipe_test.go b/ipc/namedpipe/namedpipe_test.go index ea515e3..0573d0f 100644 --- a/ipc/winpipe/winpipe_test.go +++ b/ipc/namedpipe/namedpipe_test.go @@ -1,12 +1,12 @@ -//go:build windows +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ +//go:build windows +// +build windows -package winpipe_test +package namedpipe_test import ( "bufio" @@ -22,7 +22,7 @@ import ( "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.zx2c4.com/wireguard/ipc/namedpipe" ) func randomPipePath() string { @@ -30,7 +30,7 @@ func randomPipePath() string { if err != nil { panic(err) } - return `\\.\PIPE\go-winpipe-test-` + guid.String() + return `\\.\PIPE\go-namedpipe-test-` + guid.String() } func TestPingPong(t *testing.T) { @@ -39,7 +39,7 @@ func TestPingPong(t *testing.T) { pong = 24 ) pipePath := randomPipePath() - listener, err := winpipe.Listen(pipePath, nil) + listener, err := namedpipe.Listen(pipePath) if err != nil { t.Fatalf("unable to listen on pipe: %v", err) } @@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) { t.Fatalf("unable to write pong to pipe: %v", err) } }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatalf("unable to dial pipe: %v", err) } defer client.Close() + client.SetDeadline(time.Now().Add(time.Second * 5)) var data [1]byte data[0] = ping _, err = client.Write(data[:]) @@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) { } func TestDialUnknownFailsImmediately(t *testing.T) { - _, err := winpipe.Dial(randomPipePath(), nil, nil) + _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0)) if !errors.Is(err, syscall.ENOENT) { t.Fatalf("expected ENOENT got %v", err) } @@ -93,13 +94,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) { func TestDialListenerTimesOut(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() - d := 10 * time.Millisecond - _, err = winpipe.Dial(pipePath, &d, nil) + pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond) + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -107,14 +110,17 @@ func TestDialListenerTimesOut(t *testing.T) { func TestDialContextListenerTimesOut(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() d := 10 * time.Millisecond ctx, _ := context.WithTimeout(context.Background(), d) - _, err = winpipe.DialContext(ctx, pipePath, nil) + pipe, err := namedpipe.DialContext(ctx, pipePath) + if err == nil { + pipe.Close() + } if err != context.DeadlineExceeded { t.Fatalf("expected context.DeadlineExceeded, got %v", err) } @@ -123,14 +129,14 @@ func TestDialContextListenerTimesOut(t *testing.T) { func TestDialListenerGetsCancelled(t *testing.T) { pipePath := randomPipePath() ctx, cancel := context.WithCancel(context.Background()) - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } - ch := make(chan error) defer l.Close() + ch := make(chan error) go func(ctx context.Context, ch chan error) { - _, err := winpipe.DialContext(ctx, pipePath, nil) + _, err := namedpipe.DialContext(ctx, pipePath) ch <- err }(ctx, ch) time.Sleep(time.Millisecond * 30) @@ -147,23 +153,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { } pipePath := randomPipePath() sd, _ := windows.SecurityDescriptorFromString("D:") - c := winpipe.ListenConfig{ + l, err := (&namedpipe.ListenConfig{ SecurityDescriptor: sd, - } - l, err := winpipe.Listen(pipePath, &c) + }).Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + pipe.Close() + } if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) } } -func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { +func getConnection(cfg *namedpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, cfg) + if cfg == nil { + cfg = &namedpipe.ListenConfig{} + } + l, err := cfg.Listen(pipePath) if err != nil { return } @@ -179,7 +190,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, ch <- response{c, err} }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { return } @@ -236,7 +247,7 @@ func server(l net.Listener, ch chan int) { func TestFullListenDialReadWrite(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -245,7 +256,7 @@ func TestFullListenDialReadWrite(t *testing.T) { ch := make(chan int) go server(l, ch) - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -275,7 +286,7 @@ func TestFullListenDialReadWrite(t *testing.T) { func TestCloseAbortsListen(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -328,7 +339,7 @@ func TestCloseServerEOFClient(t *testing.T) { } func TestCloseWriteEOF(t *testing.T) { - cfg := &winpipe.ListenConfig{ + cfg := &namedpipe.ListenConfig{ MessageMode: true, } c, s, err := getConnection(cfg) @@ -356,7 +367,7 @@ func TestCloseWriteEOF(t *testing.T) { func TestAcceptAfterCloseFails(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -369,12 +380,15 @@ func TestAcceptAfterCloseFails(t *testing.T) { func TestDialTimesOutByDefault(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -382,7 +396,7 @@ func TestDialTimesOutByDefault(t *testing.T) { func TestTimeoutPendingRead(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -400,7 +414,7 @@ func TestTimeoutPendingRead(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -430,7 +444,7 @@ func TestTimeoutPendingRead(t *testing.T) { func TestTimeoutPendingWrite(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -448,7 +462,7 @@ func TestTimeoutPendingWrite(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -480,13 +494,12 @@ type CloseWriter interface { } func TestEchoWithMessaging(t *testing.T) { - c := winpipe.ListenConfig{ + pipePath := randomPipePath() + l, err := (&namedpipe.ListenConfig{ MessageMode: true, // Use message mode so that CloseWrite() is supported InputBufferSize: 65536, // Use 64KB buffers to improve performance OutputBufferSize: 65536, - } - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, &c) + }).Listen(pipePath) if err != nil { t.Fatal(err) } @@ -496,19 +509,21 @@ func TestEchoWithMessaging(t *testing.T) { clientDone := make(chan bool) go func() { // server echo - conn, e := l.Accept() - if e != nil { - t.Fatal(e) + conn, err := l.Accept() + if err != nil { + t.Fatal(err) } defer conn.Close() time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent - io.Copy(conn, conn) + _, err = io.Copy(conn, conn) + if err != nil { + t.Fatal(err) + } conn.(CloseWriter).CloseWrite() close(listenerDone) }() - timeout := 1 * time.Second - client, err := winpipe.Dial(pipePath, &timeout, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Second) if err != nil { t.Fatal(err) } @@ -521,7 +536,7 @@ func TestEchoWithMessaging(t *testing.T) { if e != nil { t.Fatal(e) } - if n != 2 { + if n != 2 || bytes[0] != 0 || bytes[1] != 1 { t.Fatalf("expected 2 bytes, got %v", n) } close(clientDone) @@ -545,7 +560,7 @@ func TestEchoWithMessaging(t *testing.T) { func TestConnectRace(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -565,7 +580,7 @@ func TestConnectRace(t *testing.T) { }() for i := 0; i < 1000; i++ { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -580,7 +595,7 @@ func TestMessageReadMode(t *testing.T) { var wg sync.WaitGroup defer wg.Wait() pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) + l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath) if err != nil { t.Fatal(err) } @@ -602,7 +617,7 @@ func TestMessageReadMode(t *testing.T) { s.Close() }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -643,13 +658,13 @@ func TestListenConnectRace(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err == nil { c.Close() } wg.Done() }() - s, err := winpipe.Listen(pipePath, nil) + s, err := namedpipe.Listen(pipePath) if err != nil { t.Error(i, err) } else { diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index a4d68da..a1bfbd1 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -9,8 +9,7 @@ import ( "net" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.zx2c4.com/wireguard/ipc/namedpipe" ) // TODO: replace these with actual standard windows error numbers from the win package @@ -61,10 +60,9 @@ func init() { } func UAPIListen(name string) (net.Listener, error) { - config := winpipe.ListenConfig{ + listener, err := (&namedpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, - } - listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) + }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name) if err != nil { return nil, err } |