diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-10-30 02:39:56 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-10-30 02:39:56 +0200 |
commit | 52704c4b928889f88b1c8effcd02788000e2a780 (patch) | |
tree | 289c53916da970cba98ca73b2e91907207173ab8 | |
parent | eb6302c7eb71e3e3df9f63395bc5c97dcf0efc84 (diff) |
winpipe: update with latest changes from CL299009
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | ipc/winpipe/file.go | 6 | ||||
-rw-r--r-- | ipc/winpipe/winpipe.go | 70 | ||||
-rw-r--r-- | ipc/winpipe/winpipe_test.go | 68 |
3 files changed, 83 insertions, 61 deletions
diff --git a/ipc/winpipe/file.go b/ipc/winpipe/file.go index 319565f..b48d47a 100644 --- a/ipc/winpipe/file.go +++ b/ipc/winpipe/file.go @@ -1,11 +1,11 @@ -//go:build windows - /* SPDX-License-Identifier: MIT * - * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2015 Microsoft */ +//go:build windows + package winpipe import ( diff --git a/ipc/winpipe/winpipe.go b/ipc/winpipe/winpipe.go index e3719d6..d34ba82 100644 --- a/ipc/winpipe/winpipe.go +++ b/ipc/winpipe/winpipe.go @@ -1,11 +1,11 @@ -//go:build windows - /* SPDX-License-Identifier: MIT * - * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2015 Microsoft */ +//go:build windows + // Package winpipe implements a net.Conn and net.Listener around Windows named pipes. package winpipe @@ -15,6 +15,7 @@ import ( "net" "os" "runtime" + "sync/atomic" "time" "unsafe" @@ -28,7 +29,7 @@ type pipe struct { type messageBytePipe struct { pipe - writeClosed bool + writeClosed int32 readEOF bool } @@ -50,25 +51,26 @@ func (f *pipe) SetDeadline(t time.Time) error { // CloseWrite closes the write side of a message pipe in byte mode. func (f *messageBytePipe) CloseWrite() error { - if f.writeClosed { + if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) { return io.ErrClosedPipe } err := f.file.Flush() if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } _, err = f.file.Write(nil) if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } - f.writeClosed = true return nil } // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // they are used to implement CloseWrite. func (f *messageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { + if atomic.LoadInt32(&f.writeClosed) != 0 { return 0, io.ErrClosedPipe } if len(b) == 0 { @@ -142,30 +144,24 @@ type DialConfig struct { ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. } -// Dial connects to the specified named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. -func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(2 * time.Second) +// DialTimeout connects to the specified named pipe by path, timing out if the +// connection takes longer than the specified duration. If timeout is zero, then +// we use a default timeout of 2 seconds. +func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + if timeout == 0 { + timeout = time.Second * 2 } + absTimeout := time.Now().Add(timeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialContext(ctx, path, config) + conn, err := config.DialContext(ctx, path) if err == context.DeadlineExceeded { return nil, os.ErrDeadlineExceeded } return conn, err } -// DialContext attempts to connect to the specified named pipe by path -// cancellation or timeout. -func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { - if config == nil { - config = &DialConfig{} - } +// DialContext attempts to connect to the specified named pipe by path. +func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) { var err error var h windows.Handle h, err = tryDialPipe(ctx, &path) @@ -213,6 +209,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn return &pipe{file: f, path: path}, nil } +var defaultDialer DialConfig + +// DialTimeout calls DialConfig.DialTimeout using an empty configuration. +func DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + return defaultDialer.DialTimeout(path, timeout) +} + +// DialContext calls DialConfig.DialContext using an empty configuration. +func DialContext(ctx context.Context, path string) (net.Conn, error) { + return defaultDialer.DialContext(ctx, path) +} + type acceptResponse struct { f *file err error @@ -222,12 +230,12 @@ type pipeListener struct { firstHandle windows.Handle path string config ListenConfig - acceptCh chan (chan acceptResponse) + acceptCh chan chan acceptResponse closeCh chan int doneCh chan int } -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) { path16, err := windows.UTF16PtrFromString(path) if err != nil { return 0, &os.PathError{Op: "open", Path: path, Err: err} @@ -247,7 +255,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste oa.ObjectName = &ntPath // The security descriptor is only needed for the first pipe. - if first { + if isFirstPipe { if sd != nil { oa.SecurityDescriptor = sd } else { @@ -257,7 +265,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste return 0, err } defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) - sd, err := windows.NewSecurityDescriptor() + sd, err = windows.NewSecurityDescriptor() if err != nil { return 0, err } @@ -275,11 +283,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste disposition := uint32(windows.FILE_OPEN) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { + if isFirstPipe { disposition = windows.FILE_CREATE // By not asking for read or write access, the named pipe file system // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. + // client connections until the next call with isFirstPipe == false. access = windows.SYNCHRONIZE } @@ -407,12 +415,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) { firstHandle: h, path: path, config: *c, - acceptCh: make(chan (chan acceptResponse)), + acceptCh: make(chan chan acceptResponse), closeCh: make(chan int), doneCh: make(chan int), } // The first connection is swallowed on Windows 7 & 8, so synthesize it. - if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) { path16, err := windows.UTF16PtrFromString(path) if err == nil { h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) diff --git a/ipc/winpipe/winpipe_test.go b/ipc/winpipe/winpipe_test.go index ea515e3..7cde94e 100644 --- a/ipc/winpipe/winpipe_test.go +++ b/ipc/winpipe/winpipe_test.go @@ -1,11 +1,11 @@ -//go:build windows - /* SPDX-License-Identifier: MIT * - * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2015 Microsoft */ +//go:build windows + package winpipe_test import ( @@ -22,7 +22,7 @@ import ( "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.org/x/sys/windows/winpipe" ) func randomPipePath() string { @@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) { t.Fatalf("unable to write pong to pipe: %v", err) } }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatalf("unable to dial pipe: %v", err) } defer client.Close() + client.SetDeadline(time.Now().Add(time.Second * 5)) var data [1]byte data[0] = ping _, err = client.Write(data[:]) @@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) { } func TestDialUnknownFailsImmediately(t *testing.T) { - _, err := winpipe.Dial(randomPipePath(), nil, nil) + _, err := winpipe.DialTimeout(randomPipePath(), time.Duration(0)) if !errors.Is(err, syscall.ENOENT) { t.Fatalf("expected ENOENT got %v", err) } @@ -98,8 +99,10 @@ func TestDialListenerTimesOut(t *testing.T) { t.Fatal(err) } defer l.Close() - d := 10 * time.Millisecond - _, err = winpipe.Dial(pipePath, &d, nil) + pipe, err := winpipe.DialTimeout(pipePath, 10*time.Millisecond) + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -114,7 +117,10 @@ func TestDialContextListenerTimesOut(t *testing.T) { defer l.Close() d := 10 * time.Millisecond ctx, _ := context.WithTimeout(context.Background(), d) - _, err = winpipe.DialContext(ctx, pipePath, nil) + pipe, err := winpipe.DialContext(ctx, pipePath) + if err == nil { + pipe.Close() + } if err != context.DeadlineExceeded { t.Fatalf("expected context.DeadlineExceeded, got %v", err) } @@ -127,10 +133,10 @@ func TestDialListenerGetsCancelled(t *testing.T) { if err != nil { t.Fatal(err) } - ch := make(chan error) defer l.Close() + ch := make(chan error) go func(ctx context.Context, ch chan error) { - _, err := winpipe.DialContext(ctx, pipePath, nil) + _, err := winpipe.DialContext(ctx, pipePath) ch <- err }(ctx, ch) time.Sleep(time.Millisecond * 30) @@ -155,7 +161,10 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := winpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + pipe.Close() + } if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) } @@ -179,7 +188,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, ch <- response{c, err} }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { return } @@ -245,7 +254,7 @@ func TestFullListenDialReadWrite(t *testing.T) { ch := make(chan int) go server(l, ch) - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -374,7 +383,10 @@ func TestDialTimesOutByDefault(t *testing.T) { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := winpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -400,7 +412,7 @@ func TestTimeoutPendingRead(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -448,7 +460,7 @@ func TestTimeoutPendingWrite(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -496,19 +508,21 @@ func TestEchoWithMessaging(t *testing.T) { clientDone := make(chan bool) go func() { // server echo - conn, e := l.Accept() - if e != nil { - t.Fatal(e) + conn, err := l.Accept() + if err != nil { + t.Fatal(err) } defer conn.Close() time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent - io.Copy(conn, conn) + _, err = io.Copy(conn, conn) + if err != nil { + t.Fatal(err) + } conn.(CloseWriter).CloseWrite() close(listenerDone) }() - timeout := 1 * time.Second - client, err := winpipe.Dial(pipePath, &timeout, nil) + client, err := winpipe.DialTimeout(pipePath, time.Second) if err != nil { t.Fatal(err) } @@ -521,7 +535,7 @@ func TestEchoWithMessaging(t *testing.T) { if e != nil { t.Fatal(e) } - if n != 2 { + if n != 2 || bytes[0] != 0 || bytes[1] != 1 { t.Fatalf("expected 2 bytes, got %v", n) } close(clientDone) @@ -565,7 +579,7 @@ func TestConnectRace(t *testing.T) { }() for i := 0; i < 1000; i++ { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -602,7 +616,7 @@ func TestMessageReadMode(t *testing.T) { s.Close() }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -643,7 +657,7 @@ func TestListenConnectRace(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err == nil { c.Close() } |