diff options
Diffstat (limited to 'ipc/winpipe/winpipe.go')
-rw-r--r-- | ipc/winpipe/winpipe.go | 70 |
1 files changed, 39 insertions, 31 deletions
diff --git a/ipc/winpipe/winpipe.go b/ipc/winpipe/winpipe.go index e3719d6..d34ba82 100644 --- a/ipc/winpipe/winpipe.go +++ b/ipc/winpipe/winpipe.go @@ -1,11 +1,11 @@ -//go:build windows - /* SPDX-License-Identifier: MIT * - * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2015 Microsoft */ +//go:build windows + // Package winpipe implements a net.Conn and net.Listener around Windows named pipes. package winpipe @@ -15,6 +15,7 @@ import ( "net" "os" "runtime" + "sync/atomic" "time" "unsafe" @@ -28,7 +29,7 @@ type pipe struct { type messageBytePipe struct { pipe - writeClosed bool + writeClosed int32 readEOF bool } @@ -50,25 +51,26 @@ func (f *pipe) SetDeadline(t time.Time) error { // CloseWrite closes the write side of a message pipe in byte mode. func (f *messageBytePipe) CloseWrite() error { - if f.writeClosed { + if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) { return io.ErrClosedPipe } err := f.file.Flush() if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } _, err = f.file.Write(nil) if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } - f.writeClosed = true return nil } // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // they are used to implement CloseWrite. func (f *messageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { + if atomic.LoadInt32(&f.writeClosed) != 0 { return 0, io.ErrClosedPipe } if len(b) == 0 { @@ -142,30 +144,24 @@ type DialConfig struct { ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. } -// Dial connects to the specified named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. -func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(2 * time.Second) +// DialTimeout connects to the specified named pipe by path, timing out if the +// connection takes longer than the specified duration. If timeout is zero, then +// we use a default timeout of 2 seconds. +func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + if timeout == 0 { + timeout = time.Second * 2 } + absTimeout := time.Now().Add(timeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialContext(ctx, path, config) + conn, err := config.DialContext(ctx, path) if err == context.DeadlineExceeded { return nil, os.ErrDeadlineExceeded } return conn, err } -// DialContext attempts to connect to the specified named pipe by path -// cancellation or timeout. -func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { - if config == nil { - config = &DialConfig{} - } +// DialContext attempts to connect to the specified named pipe by path. +func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) { var err error var h windows.Handle h, err = tryDialPipe(ctx, &path) @@ -213,6 +209,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn return &pipe{file: f, path: path}, nil } +var defaultDialer DialConfig + +// DialTimeout calls DialConfig.DialTimeout using an empty configuration. +func DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + return defaultDialer.DialTimeout(path, timeout) +} + +// DialContext calls DialConfig.DialContext using an empty configuration. +func DialContext(ctx context.Context, path string) (net.Conn, error) { + return defaultDialer.DialContext(ctx, path) +} + type acceptResponse struct { f *file err error @@ -222,12 +230,12 @@ type pipeListener struct { firstHandle windows.Handle path string config ListenConfig - acceptCh chan (chan acceptResponse) + acceptCh chan chan acceptResponse closeCh chan int doneCh chan int } -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) { path16, err := windows.UTF16PtrFromString(path) if err != nil { return 0, &os.PathError{Op: "open", Path: path, Err: err} @@ -247,7 +255,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste oa.ObjectName = &ntPath // The security descriptor is only needed for the first pipe. - if first { + if isFirstPipe { if sd != nil { oa.SecurityDescriptor = sd } else { @@ -257,7 +265,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste return 0, err } defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) - sd, err := windows.NewSecurityDescriptor() + sd, err = windows.NewSecurityDescriptor() if err != nil { return 0, err } @@ -275,11 +283,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste disposition := uint32(windows.FILE_OPEN) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { + if isFirstPipe { disposition = windows.FILE_CREATE // By not asking for read or write access, the named pipe file system // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. + // client connections until the next call with isFirstPipe == false. access = windows.SYNCHRONIZE } @@ -407,12 +415,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) { firstHandle: h, path: path, config: *c, - acceptCh: make(chan (chan acceptResponse)), + acceptCh: make(chan chan acceptResponse), closeCh: make(chan int), doneCh: make(chan int), } // The first connection is swallowed on Windows 7 & 8, so synthesize it. - if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) { path16, err := windows.UTF16PtrFromString(path) if err == nil { h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) |