diff options
-rw-r--r-- | ipc/uapi_windows.go | 4 | ||||
-rw-r--r-- | ipc/winpipe/file.go | 121 | ||||
-rw-r--r-- | ipc/winpipe/mksyscall.go | 9 | ||||
-rw-r--r-- | ipc/winpipe/pipe.go | 509 | ||||
-rw-r--r-- | ipc/winpipe/winpipe.go | 474 | ||||
-rw-r--r-- | ipc/winpipe/winpipe_test.go | 660 | ||||
-rw-r--r-- | ipc/winpipe/zsyscall_windows.go | 238 |
7 files changed, 1178 insertions, 837 deletions
diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index 164b7cb..3e2709c 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -62,10 +62,10 @@ func init() { } func UAPIListen(name string) (net.Listener, error) { - config := winpipe.PipeConfig{ + config := winpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, } - listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) + listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) if err != nil { return nil, err } diff --git a/ipc/winpipe/file.go b/ipc/winpipe/file.go index f3b768f..0c9abb1 100644 --- a/ipc/winpipe/file.go +++ b/ipc/winpipe/file.go @@ -5,54 +5,21 @@ * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. */ + package winpipe import ( - "errors" "io" + "os" "runtime" "sync" "sync/atomic" "time" + "unsafe" "golang.org/x/sys/windows" ) -//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx -//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort -//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus -//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes -//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult - -type atomicBool int32 - -func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } -func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } -func (b *atomicBool) swap(new bool) bool { - var newInt int32 - if new { - newInt = 1 - } - return atomic.SwapInt32((*int32)(b), newInt) == 1 -} - -const ( - cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1 - cFILE_SKIP_SET_EVENT_ON_HANDLE = 2 -) - -var ( - ErrFileClosed = errors.New("file has already been closed") - ErrTimeout = &timeoutError{} -) - -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - type timeoutChan chan struct{} var ioInitOnce sync.Once @@ -71,7 +38,7 @@ type ioOperation struct { } func initIo() { - h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff) + h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { panic(err) } @@ -79,13 +46,13 @@ func initIo() { go ioCompletionProcessor(h) } -// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. +// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // It takes ownership of this handle and will close it if it is garbage collected. -type win32File struct { +type file struct { handle windows.Handle wg sync.WaitGroup wgLock sync.RWMutex - closing atomicBool + closing uint32 // used as atomic boolean socket bool readDeadline deadlineHandler writeDeadline deadlineHandler @@ -96,18 +63,18 @@ type deadlineHandler struct { channel timeoutChan channelLock sync.RWMutex timer *time.Timer - timedout atomicBool + timedout uint32 // used as atomic boolean } -// makeWin32File makes a new win32File from an existing file handle -func makeWin32File(h windows.Handle) (*win32File, error) { - f := &win32File{handle: h} +// makeFile makes a new file from an existing file handle +func makeFile(h windows.Handle) (*file, error) { + f := &file{handle: h} ioInitOnce.Do(initIo) - _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) + _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) if err != nil { return nil, err } - err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE) + err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) if err != nil { return nil, err } @@ -116,18 +83,14 @@ func makeWin32File(h windows.Handle) (*win32File, error) { return f, nil } -func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) { - return makeWin32File(h) -} - // closeHandle closes the resources associated with a Win32 handle -func (f *win32File) closeHandle() { +func (f *file) closeHandle() { f.wgLock.Lock() // Atomically set that we are closing, releasing the resources only once. - if !f.closing.swap(true) { + if atomic.SwapUint32(&f.closing, 1) == 0 { f.wgLock.Unlock() // cancel all IO and wait for it to complete - cancelIoEx(f.handle, nil) + windows.CancelIoEx(f.handle, nil) f.wg.Wait() // at this point, no new IO can start windows.Close(f.handle) @@ -137,19 +100,19 @@ func (f *win32File) closeHandle() { } } -// Close closes a win32File. -func (f *win32File) Close() error { +// Close closes a file. +func (f *file) Close() error { f.closeHandle() return nil } // prepareIo prepares for a new IO operation. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. -func (f *win32File) prepareIo() (*ioOperation, error) { +func (f *file) prepareIo() (*ioOperation, error) { f.wgLock.RLock() - if f.closing.isSet() { + if atomic.LoadUint32(&f.closing) == 1 { f.wgLock.RUnlock() - return nil, ErrFileClosed + return nil, os.ErrClosed } f.wg.Add(1) f.wgLock.RUnlock() @@ -164,7 +127,7 @@ func ioCompletionProcessor(h windows.Handle) { var bytes uint32 var key uintptr var op *ioOperation - err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE) + err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) if op == nil { panic(err) } @@ -174,13 +137,13 @@ func ioCompletionProcessor(h windows.Handle) { // asyncIo processes the return value from ReadFile or WriteFile, blocking until // the operation has actually completed. -func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { +func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { if err != windows.ERROR_IO_PENDING { return int(bytes), err } - if f.closing.isSet() { - cancelIoEx(f.handle, &c.o) + if atomic.LoadUint32(&f.closing) == 1 { + windows.CancelIoEx(f.handle, &c.o) } var timeout timeoutChan @@ -195,20 +158,20 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er case r = <-c.ch: err = r.err if err == windows.ERROR_OPERATION_ABORTED { - if f.closing.isSet() { - err = ErrFileClosed + if atomic.LoadUint32(&f.closing) == 1 { + err = os.ErrClosed } } else if err != nil && f.socket { // err is from Win32. Query the overlapped structure to get the winsock error. var bytes, flags uint32 - err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) + err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) } case <-timeout: - cancelIoEx(f.handle, &c.o) + windows.CancelIoEx(f.handle, &c.o) r = <-c.ch err = r.err if err == windows.ERROR_OPERATION_ABORTED { - err = ErrTimeout + err = os.ErrDeadlineExceeded } } @@ -220,15 +183,15 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er } // Read reads from a file handle. -func (f *win32File) Read(b []byte) (int, error) { +func (f *file) Read(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.readDeadline.timedout.isSet() { - return 0, ErrTimeout + if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -247,15 +210,15 @@ func (f *win32File) Read(b []byte) (int, error) { } // Write writes to a file handle. -func (f *win32File) Write(b []byte) (int, error) { +func (f *file) Write(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.writeDeadline.timedout.isSet() { - return 0, ErrTimeout + if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -265,19 +228,19 @@ func (f *win32File) Write(b []byte) (int, error) { return n, err } -func (f *win32File) SetReadDeadline(deadline time.Time) error { +func (f *file) SetReadDeadline(deadline time.Time) error { return f.readDeadline.set(deadline) } -func (f *win32File) SetWriteDeadline(deadline time.Time) error { +func (f *file) SetWriteDeadline(deadline time.Time) error { return f.writeDeadline.set(deadline) } -func (f *win32File) Flush() error { +func (f *file) Flush() error { return windows.FlushFileBuffers(f.handle) } -func (f *win32File) Fd() uintptr { +func (f *file) Fd() uintptr { return uintptr(f.handle) } @@ -291,7 +254,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } d.timer = nil } - d.timedout.setFalse() + atomic.StoreUint32(&d.timedout, 0) select { case <-d.channel: @@ -306,7 +269,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } timeoutIO := func() { - d.timedout.setTrue() + atomic.StoreUint32(&d.timedout, 1) close(d.channel) } diff --git a/ipc/winpipe/mksyscall.go b/ipc/winpipe/mksyscall.go deleted file mode 100644 index a87e929..0000000 --- a/ipc/winpipe/mksyscall.go +++ /dev/null @@ -1,9 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go deleted file mode 100644 index e609274..0000000 --- a/ipc/winpipe/pipe.go +++ /dev/null @@ -1,509 +0,0 @@ -// +build windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "os" - "runtime" - "time" - "unsafe" - - "golang.org/x/sys/windows" -) - -//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe -//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW -//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW -//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo -//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW -//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc -//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile -//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb -//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U -//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl - -type ioStatusBlock struct { - Status, Information uintptr -} - -type objectAttributes struct { - Length uintptr - RootDirectory uintptr - ObjectName *unicodeString - Attributes uintptr - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - SecurityQoS uintptr -} - -type unicodeString struct { - Length uint16 - MaximumLength uint16 - Buffer uintptr -} - -type ntstatus int32 - -func (status ntstatus) Err() error { - if status >= 0 { - return nil - } - return rtlNtStatusToDosError(status) -} - -const ( - cSECURITY_SQOS_PRESENT = 0x100000 - cSECURITY_ANONYMOUS = 0 - - cPIPE_TYPE_MESSAGE = 4 - - cPIPE_READMODE_MESSAGE = 2 - - cFILE_OPEN = 1 - cFILE_CREATE = 2 - - cFILE_PIPE_MESSAGE_TYPE = 1 - cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2 -) - -var ( - // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed. - // This error should match net.errClosing since docker takes a dependency on its text. - ErrPipeListenerClosed = errors.New("use of closed network connection") - - errPipeWriteClosed = errors.New("pipe has been closed for write") -) - -type win32Pipe struct { - *win32File - path string -} - -type win32MessageBytePipe struct { - win32Pipe - writeClosed bool - readEOF bool -} - -type pipeAddress string - -func (f *win32Pipe) LocalAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) RemoteAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) SetDeadline(t time.Time) error { - f.SetReadDeadline(t) - f.SetWriteDeadline(t) - return nil -} - -// CloseWrite closes the write side of a message pipe in byte mode. -func (f *win32MessageBytePipe) CloseWrite() error { - if f.writeClosed { - return errPipeWriteClosed - } - err := f.win32File.Flush() - if err != nil { - return err - } - _, err = f.win32File.Write(nil) - if err != nil { - return err - } - f.writeClosed = true - return nil -} - -// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since -// they are used to implement CloseWrite(). -func (f *win32MessageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { - return 0, errPipeWriteClosed - } - if len(b) == 0 { - return 0, nil - } - return f.win32File.Write(b) -} - -// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message -// mode pipe will return io.EOF, as will all subsequent reads. -func (f *win32MessageBytePipe) Read(b []byte) (int, error) { - if f.readEOF { - return 0, io.EOF - } - n, err := f.win32File.Read(b) - if err == io.EOF { - // If this was the result of a zero-byte read, then - // it is possible that the read was due to a zero-size - // message. Since we are simulating CloseWrite with a - // zero-byte message, ensure that all future Read() calls - // also return EOF. - f.readEOF = true - } else if err == windows.ERROR_MORE_DATA { - // ERROR_MORE_DATA indicates that the pipe's read mode is message mode - // and the message still has more bytes. Treat this as a success, since - // this package presents all named pipes as byte streams. - err = nil - } - return n, err -} - -func (s pipeAddress) Network() string { - return "pipe" -} - -func (s pipeAddress) String() string { - return string(s) -} - -// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. -func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { - for { - select { - case <-ctx.Done(): - return windows.Handle(0), ctx.Err() - default: - h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) - if err == nil { - return h, nil - } - if err != windows.ERROR_PIPE_BUSY { - return h, &os.PathError{Err: err, Op: "open", Path: *path} - } - // Wait 10 msec and try again. This is a rather simplistic - // view, as we always try each 10 milliseconds. - time.Sleep(time.Millisecond * 10) - } - } -} - -// DialPipe connects to a named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. (We do not use WaitNamedPipe.) -func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(time.Second * 2) - } - ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialPipeContext(ctx, path, expectedOwner) - if err == context.DeadlineExceeded { - return nil, ErrTimeout - } - return conn, err -} - -// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` -// cancellation or timeout. -func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) { - var err error - var h windows.Handle - h, err = tryDialPipe(ctx, &path) - if err != nil { - return nil, err - } - - if expectedOwner != nil { - sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - windows.Close(h) - return nil, err - } - realOwner, _, err := sd.Owner() - if err != nil { - windows.Close(h) - return nil, err - } - if !realOwner.Equals(expectedOwner) { - windows.Close(h) - return nil, windows.ERROR_ACCESS_DENIED - } - } - - var flags uint32 - err = getNamedPipeInfo(h, &flags, nil, nil, nil) - if err != nil { - windows.Close(h) - return nil, err - } - - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - - // If the pipe is in message mode, return a message byte pipe, which - // supports CloseWrite(). - if flags&cPIPE_TYPE_MESSAGE != 0 { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: f, path: path}, - }, nil - } - return &win32Pipe{win32File: f, path: path}, nil -} - -type acceptResponse struct { - f *win32File - err error -} - -type win32PipeListener struct { - firstHandle windows.Handle - path string - config PipeConfig - acceptCh chan (chan acceptResponse) - closeCh chan int - doneCh chan int -} - -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) { - path16, err := windows.UTF16FromString(path) - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - var oa objectAttributes - oa.Length = unsafe.Sizeof(oa) - - var ntPath unicodeString - if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - defer windows.LocalFree(windows.Handle(ntPath.Buffer)) - oa.ObjectName = &ntPath - - // The security descriptor is only needed for the first pipe. - if first { - if sd != nil { - oa.SecurityDescriptor = sd - } else { - // Construct the default named pipe security descriptor. - var dacl uintptr - if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { - return 0, fmt.Errorf("getting default named pipe ACL: %s", err) - } - defer windows.LocalFree(windows.Handle(dacl)) - sd, err := windows.NewSecurityDescriptor() - if err != nil { - return 0, fmt.Errorf("creating new security descriptor: %s", err) - } - if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil { - return 0, fmt.Errorf("assigning dacl: %s", err) - } - sd, err = sd.ToSelfRelative() - if err != nil { - return 0, fmt.Errorf("converting to self-relative: %s", err) - } - oa.SecurityDescriptor = sd - } - } - - typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS) - if c.MessageMode { - typ |= cFILE_PIPE_MESSAGE_TYPE - } - - disposition := uint32(cFILE_OPEN) - access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { - disposition = cFILE_CREATE - // By not asking for read or write access, the named pipe file system - // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. - access = windows.SYNCHRONIZE - } - - timeout := int64(-50 * 10000) // 50ms - - var ( - h windows.Handle - iosb ioStatusBlock - ) - err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - runtime.KeepAlive(ntPath) - return h, nil -} - -func (l *win32PipeListener) makeServerPipe() (*win32File, error) { - h, err := makeServerPipeHandle(l.path, nil, &l.config, false) - if err != nil { - return nil, err - } - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - return f, nil -} - -func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) { - p, err := l.makeServerPipe() - if err != nil { - return nil, err - } - - // Wait for the client to connect. - ch := make(chan error) - go func(p *win32File) { - ch <- connectPipe(p) - }(p) - - select { - case err = <-ch: - if err != nil { - p.Close() - p = nil - } - case <-l.closeCh: - // Abort the connect request by closing the handle. - p.Close() - p = nil - err = <-ch - if err == nil || err == ErrFileClosed { - err = ErrPipeListenerClosed - } - } - return p, err -} - -func (l *win32PipeListener) listenerRoutine() { - closed := false - for !closed { - select { - case <-l.closeCh: - closed = true - case responseCh := <-l.acceptCh: - var ( - p *win32File - err error - ) - for { - p, err = l.makeConnectedServerPipe() - // If the connection was immediately closed by the client, try - // again. - if err != windows.ERROR_NO_DATA { - break - } - } - responseCh <- acceptResponse{p, err} - closed = err == ErrPipeListenerClosed - } - } - windows.Close(l.firstHandle) - l.firstHandle = 0 - // Notify Close() and Accept() callers that the handle has been closed. - close(l.doneCh) -} - -// PipeConfig contain configuration for the pipe listener. -type PipeConfig struct { - // SecurityDescriptor contains a Windows security descriptor. - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - - // MessageMode determines whether the pipe is in byte or message mode. In either - // case the pipe is read in byte mode by default. The only practical difference in - // this implementation is that CloseWrite() is only supported for message mode pipes; - // CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only - // transferred to the reader (and returned as io.EOF in this implementation) - // when the pipe is in message mode. - MessageMode bool - - // InputBufferSize specifies the size the input buffer, in bytes. - InputBufferSize int32 - - // OutputBufferSize specifies the size the input buffer, in bytes. - OutputBufferSize int32 -} - -// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe. -// The pipe must not already exist. -func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { - if c == nil { - c = &PipeConfig{} - } - h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) - if err != nil { - return nil, err - } - l := &win32PipeListener{ - firstHandle: h, - path: path, - config: *c, - acceptCh: make(chan (chan acceptResponse)), - closeCh: make(chan int), - doneCh: make(chan int), - } - go l.listenerRoutine() - return l, nil -} - -func connectPipe(p *win32File) error { - c, err := p.prepareIo() - if err != nil { - return err - } - defer p.wg.Done() - - err = connectNamedPipe(p.handle, &c.o) - _, err = p.asyncIo(c, nil, 0, err) - if err != nil && err != windows.ERROR_PIPE_CONNECTED { - return err - } - return nil -} - -func (l *win32PipeListener) Accept() (net.Conn, error) { - ch := make(chan acceptResponse) - select { - case l.acceptCh <- ch: - response := <-ch - err := response.err - if err != nil { - return nil, err - } - if l.config.MessageMode { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: response.f, path: l.path}, - }, nil - } - return &win32Pipe{win32File: response.f, path: l.path}, nil - case <-l.doneCh: - return nil, ErrPipeListenerClosed - } -} - -func (l *win32PipeListener) Close() error { - select { - case l.closeCh <- 1: - <-l.doneCh - case <-l.doneCh: - } - return nil -} - -func (l *win32PipeListener) Addr() net.Addr { - return pipeAddress(l.path) -} diff --git a/ipc/winpipe/winpipe.go b/ipc/winpipe/winpipe.go new file mode 100644 index 0000000..f02f3d8 --- /dev/null +++ b/ipc/winpipe/winpipe.go @@ -0,0 +1,474 @@ +// +build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2005 Microsoft + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +// Package winpipe implements a net.Conn and net.Listener around Windows named pipes. +package winpipe + +import ( + "context" + "io" + "net" + "os" + "runtime" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +type pipe struct { + *file + path string +} + +type messageBytePipe struct { + pipe + writeClosed bool + readEOF bool +} + +type pipeAddress string + +func (f *pipe) LocalAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) RemoteAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) SetDeadline(t time.Time) error { + f.SetReadDeadline(t) + f.SetWriteDeadline(t) + return nil +} + +// CloseWrite closes the write side of a message pipe in byte mode. +func (f *messageBytePipe) CloseWrite() error { + if f.writeClosed { + return io.ErrClosedPipe + } + err := f.file.Flush() + if err != nil { + return err + } + _, err = f.file.Write(nil) + if err != nil { + return err + } + f.writeClosed = true + return nil +} + +// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since +// they are used to implement CloseWrite. +func (f *messageBytePipe) Write(b []byte) (int, error) { + if f.writeClosed { + return 0, io.ErrClosedPipe + } + if len(b) == 0 { + return 0, nil + } + return f.file.Write(b) +} + +// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message +// mode pipe will return io.EOF, as will all subsequent reads. +func (f *messageBytePipe) Read(b []byte) (int, error) { + if f.readEOF { + return 0, io.EOF + } + n, err := f.file.Read(b) + if err == io.EOF { + // If this was the result of a zero-byte read, then + // it is possible that the read was due to a zero-size + // message. Since we are simulating CloseWrite with a + // zero-byte message, ensure that all future Read calls + // also return EOF. + f.readEOF = true + } else if err == windows.ERROR_MORE_DATA { + // ERROR_MORE_DATA indicates that the pipe's read mode is message mode + // and the message still has more bytes. Treat this as a success, since + // this package presents all named pipes as byte streams. + err = nil + } + return n, err +} + +func (f *pipe) Handle() windows.Handle { + return f.handle +} + +func (s pipeAddress) Network() string { + return "pipe" +} + +func (s pipeAddress) String() string { + return string(s) +} + +// tryDialPipe attempts to dial the specified pipe until cancellation or timeout. +func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { + for { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + path16, err := windows.UTF16PtrFromString(*path) + if err != nil { + return 0, err + } + h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + return h, nil + } + if err != windows.ERROR_PIPE_BUSY { + return h, &os.PathError{Err: err, Op: "open", Path: *path} + } + // Wait 10 msec and try again. This is a rather simplistic + // view, as we always try each 10 milliseconds. + time.Sleep(10 * time.Millisecond) + } + } +} + +// DialConfig exposes various options for use in Dial and DialContext. +type DialConfig struct { + ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. +} + +// Dial connects to the specified named pipe by path, timing out if the connection +// takes longer than the specified duration. If timeout is nil, then we use +// a default timeout of 2 seconds. +func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { + var absTimeout time.Time + if timeout != nil { + absTimeout = time.Now().Add(*timeout) + } else { + absTimeout = time.Now().Add(2 * time.Second) + } + ctx, _ := context.WithDeadline(context.Background(), absTimeout) + conn, err := DialContext(ctx, path, config) + if err == context.DeadlineExceeded { + return nil, os.ErrDeadlineExceeded + } + return conn, err +} + +// DialContext attempts to connect to the specified named pipe by path +// cancellation or timeout. +func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { + if config == nil { + config = &DialConfig{} + } + var err error + var h windows.Handle + h, err = tryDialPipe(ctx, &path) + if err != nil { + return nil, err + } + + if config.ExpectedOwner != nil { + sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) + if err != nil { + windows.Close(h) + return nil, err + } + realOwner, _, err := sd.Owner() + if err != nil { + windows.Close(h) + return nil, err + } + if !realOwner.Equals(config.ExpectedOwner) { + windows.Close(h) + return nil, windows.ERROR_ACCESS_DENIED + } + } + + var flags uint32 + err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil) + if err != nil { + windows.Close(h) + return nil, err + } + + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + + // If the pipe is in message mode, return a message byte pipe, which + // supports CloseWrite. + if flags&windows.PIPE_TYPE_MESSAGE != 0 { + return &messageBytePipe{ + pipe: pipe{file: f, path: path}, + }, nil + } + return &pipe{file: f, path: path}, nil +} + +type acceptResponse struct { + f *file + err error +} + +type pipeListener struct { + firstHandle windows.Handle + path string + config ListenConfig + acceptCh chan (chan acceptResponse) + closeCh chan int + doneCh chan int +} + +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + var oa windows.OBJECT_ATTRIBUTES + oa.Length = uint32(unsafe.Sizeof(oa)) + + var ntPath windows.NTUnicodeString + if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer))) + oa.ObjectName = &ntPath + + // The security descriptor is only needed for the first pipe. + if first { + if sd != nil { + oa.SecurityDescriptor = sd + } else { + // Construct the default named pipe security descriptor. + var acl *windows.ACL + if err := windows.RtlDefaultNpAcl(&acl); err != nil { + return 0, err + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) + sd, err := windows.NewSecurityDescriptor() + if err != nil { + return 0, err + } + if err = sd.SetDACL(acl, true, false); err != nil { + return 0, err + } + oa.SecurityDescriptor = sd + } + } + + typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) + if c.MessageMode { + typ |= windows.FILE_PIPE_MESSAGE_TYPE + } + + disposition := uint32(windows.FILE_OPEN) + access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) + if first { + disposition = windows.FILE_CREATE + // By not asking for read or write access, the named pipe file system + // will put this pipe into an initially disconnected state, blocking + // client connections until the next call with first == false. + access = windows.SYNCHRONIZE + } + + timeout := int64(-50 * 10000) // 50ms + + var ( + h windows.Handle + iosb windows.IO_STATUS_BLOCK + ) + err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + runtime.KeepAlive(ntPath) + return h, nil +} + +func (l *pipeListener) makeServerPipe() (*file, error) { + h, err := makeServerPipeHandle(l.path, nil, &l.config, false) + if err != nil { + return nil, err + } + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + return f, nil +} + +func (l *pipeListener) makeConnectedServerPipe() (*file, error) { + p, err := l.makeServerPipe() + if err != nil { + return nil, err + } + + // Wait for the client to connect. + ch := make(chan error) + go func(p *file) { + ch <- connectPipe(p) + }(p) + + select { + case err = <-ch: + if err != nil { + p.Close() + p = nil + } + case <-l.closeCh: + // Abort the connect request by closing the handle. + p.Close() + p = nil + err = <-ch + if err == nil || err == os.ErrClosed { + err = net.ErrClosed + } + } + return p, err +} + +func (l *pipeListener) listenerRoutine() { + closed := false + for !closed { + select { + case <-l.closeCh: + closed = true + case responseCh := <-l.acceptCh: + var ( + p *file + err error + ) + for { + p, err = l.makeConnectedServerPipe() + // If the connection was immediately closed by the client, try + // again. + if err != windows.ERROR_NO_DATA { + break + } + } + responseCh <- acceptResponse{p, err} + closed = err == net.ErrClosed + } + } + windows.Close(l.firstHandle) + l.firstHandle = 0 + // Notify Close and Accept callers that the handle has been closed. + close(l.doneCh) +} + +// ListenConfig contains configuration for the pipe listener. +type ListenConfig struct { + // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used. + SecurityDescriptor *windows.SECURITY_DESCRIPTOR + + // MessageMode determines whether the pipe is in byte or message mode. In either + // case the pipe is read in byte mode by default. The only practical difference in + // this implementation is that CloseWrite is only supported for message mode pipes; + // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only + // transferred to the reader (and returned as io.EOF in this implementation) + // when the pipe is in message mode. + MessageMode bool + + // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed. + InputBufferSize int32 + + // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed. + OutputBufferSize int32 +} + +// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. +// The pipe must not already exist. +func Listen(path string, c *ListenConfig) (net.Listener, error) { + if c == nil { + c = &ListenConfig{} + } + h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) + if err != nil { + return nil, err + } + l := &pipeListener{ + firstHandle: h, + path: path, + config: *c, + acceptCh: make(chan (chan acceptResponse)), + closeCh: make(chan int), + doneCh: make(chan int), + } + // The first connection is swallowed on Windows 7 & 8, so synthesize it. + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + path16, err := windows.UTF16PtrFromString(path) + if err == nil { + h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + windows.CloseHandle(h) + } + } + } + go l.listenerRoutine() + return l, nil +} + +func connectPipe(p *file) error { + c, err := p.prepareIo() + if err != nil { + return err + } + defer p.wg.Done() + + err = windows.ConnectNamedPipe(p.handle, &c.o) + _, err = p.asyncIo(c, nil, 0, err) + if err != nil && err != windows.ERROR_PIPE_CONNECTED { + return err + } + return nil +} + +func (l *pipeListener) Accept() (net.Conn, error) { + ch := make(chan acceptResponse) + select { + case l.acceptCh <- ch: + response := <-ch + err := response.err + if err != nil { + return nil, err + } + if l.config.MessageMode { + return &messageBytePipe{ + pipe: pipe{file: response.f, path: l.path}, + }, nil + } + return &pipe{file: response.f, path: l.path}, nil + case <-l.doneCh: + return nil, net.ErrClosed + } +} + +func (l *pipeListener) Close() error { + select { + case l.closeCh <- 1: + <-l.doneCh + case <-l.doneCh: + } + return nil +} + +func (l *pipeListener) Addr() net.Addr { + return pipeAddress(l.path) +} diff --git a/ipc/winpipe/winpipe_test.go b/ipc/winpipe/winpipe_test.go new file mode 100644 index 0000000..ee8dc8c --- /dev/null +++ b/ipc/winpipe/winpipe_test.go @@ -0,0 +1,660 @@ +// +build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2005 Microsoft + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package winpipe_test + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net" + "os" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/ipc/winpipe" +) + +func randomPipePath() string { + guid, err := windows.GenerateGUID() + if err != nil { + panic(err) + } + return `\\.\PIPE\go-winpipe-test-` + guid.String() +} + +func TestPingPong(t *testing.T) { + const ( + ping = 42 + pong = 24 + ) + pipePath := randomPipePath() + listener, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatalf("unable to listen on pipe: %v", err) + } + defer listener.Close() + go func() { + incoming, err := listener.Accept() + if err != nil { + t.Fatalf("unable to accept pipe connection: %v", err) + } + defer incoming.Close() + var data [1]byte + _, err = incoming.Read(data[:]) + if err != nil { + t.Fatalf("unable to read ping from pipe: %v", err) + } + if data[0] != ping { + t.Fatalf("expected ping, got %d", data[0]) + } + data[0] = pong + _, err = incoming.Write(data[:]) + if err != nil { + t.Fatalf("unable to write pong to pipe: %v", err) + } + }() + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatalf("unable to dial pipe: %v", err) + } + defer client.Close() + var data [1]byte + data[0] = ping + _, err = client.Write(data[:]) + if err != nil { + t.Fatalf("unable to write ping to pipe: %v", err) + } + _, err = client.Read(data[:]) + if err != nil { + t.Fatalf("unable to read pong from pipe: %v", err) + } + if data[0] != pong { + t.Fatalf("expected pong, got %d", data[0]) + } +} + +func TestDialUnknownFailsImmediately(t *testing.T) { + _, err := winpipe.Dial(randomPipePath(), nil, nil) + if !errors.Is(err, syscall.ENOENT) { + t.Fatalf("expected ENOENT got %v", err) + } +} + +func TestDialListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + _, err = winpipe.Dial(pipePath, &d, nil) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestDialContextListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + ctx, _ := context.WithTimeout(context.Background(), d) + _, err = winpipe.DialContext(ctx, pipePath, nil) + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestDialListenerGetsCancelled(t *testing.T) { + pipePath := randomPipePath() + ctx, cancel := context.WithCancel(context.Background()) + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + ch := make(chan error) + defer l.Close() + go func(ctx context.Context, ch chan error) { + _, err := winpipe.DialContext(ctx, pipePath, nil) + ch <- err + }(ctx, ch) + time.Sleep(time.Millisecond * 30) + cancel() + err = <-ch + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { + if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil { + t.Skip("dacls on named pipes are broken on wine") + } + pipePath := randomPipePath() + sd, _ := windows.SecurityDescriptorFromString("D:") + c := winpipe.ListenConfig{ + SecurityDescriptor: sd, + } + l, err := winpipe.Listen(pipePath, &c) + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, err = winpipe.Dial(pipePath, nil, nil) + if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { + t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) + } +} + +func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, cfg) + if err != nil { + return + } + defer l.Close() + + type response struct { + c net.Conn + err error + } + ch := make(chan response) + go func() { + c, err := l.Accept() + ch <- response{c, err} + }() + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + return + } + + r := <-ch + if err = r.err; err != nil { + c.Close() + return + } + + client = c + server = r.c + return +} + +func TestReadTimeout(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + + buf := make([]byte, 10) + _, err = c.Read(buf) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func server(l net.Listener, ch chan int) { + c, err := l.Accept() + if err != nil { + panic(err) + } + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + s, err := rw.ReadString('\n') + if err != nil { + panic(err) + } + _, err = rw.WriteString("got " + s) + if err != nil { + panic(err) + } + err = rw.Flush() + if err != nil { + panic(err) + } + c.Close() + ch <- 1 +} + +func TestFullListenDialReadWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ch := make(chan int) + go server(l, ch) + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + _, err = rw.WriteString("hello world\n") + if err != nil { + t.Fatal(err) + } + err = rw.Flush() + if err != nil { + t.Fatal(err) + } + + s, err := rw.ReadString('\n') + if err != nil { + t.Fatal(err) + } + ms := "got hello world\n" + if s != ms { + t.Errorf("expected '%s', got '%s'", ms, s) + } + + <-ch +} + +func TestCloseAbortsListen(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + + ch := make(chan error) + go func() { + _, err := l.Accept() + ch <- err + }() + + time.Sleep(30 * time.Millisecond) + l.Close() + + err = <-ch + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { + b := make([]byte, 10) + w.Close() + n, err := r.Read(b) + if n > 0 { + t.Errorf("unexpected byte count %d", n) + } + if err != io.EOF { + t.Errorf("expected EOF: %v", err) + } +} + +func TestCloseClientEOFServer(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, c, s) +} + +func TestCloseServerEOFClient(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, s, c) +} + +func TestCloseWriteEOF(t *testing.T) { + cfg := &winpipe.ListenConfig{ + MessageMode: true, + } + c, s, err := getConnection(cfg) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + type closeWriter interface { + CloseWrite() error + } + + err = c.(closeWriter).CloseWrite() + if err != nil { + t.Fatal(err) + } + + b := make([]byte, 10) + _, err = s.Read(b) + if err != io.EOF { + t.Fatal(err) + } +} + +func TestAcceptAfterCloseFails(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + l.Close() + _, err = l.Accept() + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func TestDialTimesOutByDefault(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, err = winpipe.Dial(pipePath, nil, nil) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestTimeoutPendingRead(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + buf := make([]byte, 10) + _, err = client.Read(buf) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline + client.SetReadDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for read to cancel") + <-clientErr + } + <-serverDone +} + +func TestTimeoutPendingWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + _, err = client.Write([]byte("this should timeout")) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline + client.SetWriteDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for write to cancel") + <-clientErr + } + <-serverDone +} + +type CloseWriter interface { + CloseWrite() error +} + +func TestEchoWithMessaging(t *testing.T) { + c := winpipe.ListenConfig{ + MessageMode: true, // Use message mode so that CloseWrite() is supported + InputBufferSize: 65536, // Use 64KB buffers to improve performance + OutputBufferSize: 65536, + } + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, &c) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + listenerDone := make(chan bool) + clientDone := make(chan bool) + go func() { + // server echo + conn, e := l.Accept() + if e != nil { + t.Fatal(e) + } + defer conn.Close() + + time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent + io.Copy(conn, conn) + conn.(CloseWriter).CloseWrite() + close(listenerDone) + }() + timeout := 1 * time.Second + client, err := winpipe.Dial(pipePath, &timeout, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + go func() { + // client read back + bytes := make([]byte, 2) + n, e := client.Read(bytes) + if e != nil { + t.Fatal(e) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + close(clientDone) + }() + + payload := make([]byte, 2) + payload[0] = 0 + payload[1] = 1 + + n, err := client.Write(payload) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + client.(CloseWriter).CloseWrite() + <-listenerDone + <-clientDone +} + +func TestConnectRace(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + go func() { + for { + s, err := l.Accept() + if err == net.ErrClosed { + return + } + + if err != nil { + t.Fatal(err) + } + s.Close() + } + }() + + for i := 0; i < 1000; i++ { + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + c.Close() + } +} + +func TestMessageReadMode(t *testing.T) { + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + t.Skipf("Skipping on Windows %d", maj) + } + var wg sync.WaitGroup + defer wg.Wait() + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + msg := ([]byte)("hello world") + + wg.Add(1) + go func() { + defer wg.Done() + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + _, err = s.Write(msg) + if err != nil { + t.Fatal(err) + } + s.Close() + }() + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + mode := uint32(windows.PIPE_READMODE_MESSAGE) + err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) + if err != nil { + t.Fatal(err) + } + + ch := make([]byte, 1) + var vmsg []byte + for { + n, err := c.Read(ch) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected 1, got %d", n) + } + vmsg = append(vmsg, ch[0]) + } + if !bytes.Equal(msg, vmsg) { + t.Fatalf("expected %s, got %s", msg, vmsg) + } +} + +func TestListenConnectRace(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long race test") + } + pipePath := randomPipePath() + for i := 0; i < 50 && !t.Failed(); i++ { + var wg sync.WaitGroup + wg.Add(1) + go func() { + c, err := winpipe.Dial(pipePath, nil, nil) + if err == nil { + c.Close() + } + wg.Done() + }() + s, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Error(i, err) + } else { + s.Close() + } + wg.Wait() + } +} diff --git a/ipc/winpipe/zsyscall_windows.go b/ipc/winpipe/zsyscall_windows.go deleted file mode 100644 index 9954329..0000000 --- a/ipc/winpipe/zsyscall_windows.go +++ /dev/null @@ -1,238 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package winpipe - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") - - procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") - procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") - procCreateFileW = modkernel32.NewProc("CreateFileW") - procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") - procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") - procLocalAlloc = modkernel32.NewProc("LocalAlloc") - procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") - procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") - procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") - procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl") - procCancelIoEx = modkernel32.NewProc("CancelIoEx") - procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") - procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") - procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes") - procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") -) - -func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) -} - -func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile) -} - -func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) { - r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func localAlloc(uFlags uint32, length uint32) (ptr uintptr) { - r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0) - ptr = uintptr(r0) - return -} - -func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { - r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) - status = ntstatus(r0) - return -} - -func rtlNtStatusToDosError(status ntstatus) (winerr error) { - r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0) - if r0 != 0 { - winerr = syscall.Errno(r0) - } - return -} - -func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0) - status = ntstatus(r0) - return -} - -func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0) - status = ntstatus(r0) - return -} - -func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) { - r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0) - newport = windows.Handle(r0) - if newport == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) { - r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { - var _p0 uint32 - if wait { - _p0 = 1 - } else { - _p0 = 0 - } - r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} |