summaryrefslogtreecommitdiffhomepage
path: root/ipc/winpipe/winpipe_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'ipc/winpipe/winpipe_test.go')
-rw-r--r--ipc/winpipe/winpipe_test.go68
1 files changed, 41 insertions, 27 deletions
diff --git a/ipc/winpipe/winpipe_test.go b/ipc/winpipe/winpipe_test.go
index ea515e3..7cde94e 100644
--- a/ipc/winpipe/winpipe_test.go
+++ b/ipc/winpipe/winpipe_test.go
@@ -1,11 +1,11 @@
-//go:build windows
-
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2015 Microsoft
*/
+//go:build windows
+
package winpipe_test
import (
@@ -22,7 +22,7 @@ import (
"time"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/ipc/winpipe"
+ "golang.org/x/sys/windows/winpipe"
)
func randomPipePath() string {
@@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) {
t.Fatalf("unable to write pong to pipe: %v", err)
}
}()
- client, err := winpipe.Dial(pipePath, nil, nil)
+ client, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatalf("unable to dial pipe: %v", err)
}
defer client.Close()
+ client.SetDeadline(time.Now().Add(time.Second * 5))
var data [1]byte
data[0] = ping
_, err = client.Write(data[:])
@@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) {
}
func TestDialUnknownFailsImmediately(t *testing.T) {
- _, err := winpipe.Dial(randomPipePath(), nil, nil)
+ _, err := winpipe.DialTimeout(randomPipePath(), time.Duration(0))
if !errors.Is(err, syscall.ENOENT) {
t.Fatalf("expected ENOENT got %v", err)
}
@@ -98,8 +99,10 @@ func TestDialListenerTimesOut(t *testing.T) {
t.Fatal(err)
}
defer l.Close()
- d := 10 * time.Millisecond
- _, err = winpipe.Dial(pipePath, &d, nil)
+ pipe, err := winpipe.DialTimeout(pipePath, 10*time.Millisecond)
+ if err == nil {
+ pipe.Close()
+ }
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
@@ -114,7 +117,10 @@ func TestDialContextListenerTimesOut(t *testing.T) {
defer l.Close()
d := 10 * time.Millisecond
ctx, _ := context.WithTimeout(context.Background(), d)
- _, err = winpipe.DialContext(ctx, pipePath, nil)
+ pipe, err := winpipe.DialContext(ctx, pipePath)
+ if err == nil {
+ pipe.Close()
+ }
if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
}
@@ -127,10 +133,10 @@ func TestDialListenerGetsCancelled(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- ch := make(chan error)
defer l.Close()
+ ch := make(chan error)
go func(ctx context.Context, ch chan error) {
- _, err := winpipe.DialContext(ctx, pipePath, nil)
+ _, err := winpipe.DialContext(ctx, pipePath)
ch <- err
}(ctx, ch)
time.Sleep(time.Millisecond * 30)
@@ -155,7 +161,10 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
t.Fatal(err)
}
defer l.Close()
- _, err = winpipe.Dial(pipePath, nil, nil)
+ pipe, err := winpipe.DialTimeout(pipePath, time.Duration(0))
+ if err == nil {
+ pipe.Close()
+ }
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
}
@@ -179,7 +188,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn,
ch <- response{c, err}
}()
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
return
}
@@ -245,7 +254,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
ch := make(chan int)
go server(l, ch)
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -374,7 +383,10 @@ func TestDialTimesOutByDefault(t *testing.T) {
t.Fatal(err)
}
defer l.Close()
- _, err = winpipe.Dial(pipePath, nil, nil)
+ pipe, err := winpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
+ if err == nil {
+ pipe.Close()
+ }
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
@@ -400,7 +412,7 @@ func TestTimeoutPendingRead(t *testing.T) {
close(serverDone)
}()
- client, err := winpipe.Dial(pipePath, nil, nil)
+ client, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -448,7 +460,7 @@ func TestTimeoutPendingWrite(t *testing.T) {
close(serverDone)
}()
- client, err := winpipe.Dial(pipePath, nil, nil)
+ client, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -496,19 +508,21 @@ func TestEchoWithMessaging(t *testing.T) {
clientDone := make(chan bool)
go func() {
// server echo
- conn, e := l.Accept()
- if e != nil {
- t.Fatal(e)
+ conn, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
}
defer conn.Close()
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
- io.Copy(conn, conn)
+ _, err = io.Copy(conn, conn)
+ if err != nil {
+ t.Fatal(err)
+ }
conn.(CloseWriter).CloseWrite()
close(listenerDone)
}()
- timeout := 1 * time.Second
- client, err := winpipe.Dial(pipePath, &timeout, nil)
+ client, err := winpipe.DialTimeout(pipePath, time.Second)
if err != nil {
t.Fatal(err)
}
@@ -521,7 +535,7 @@ func TestEchoWithMessaging(t *testing.T) {
if e != nil {
t.Fatal(e)
}
- if n != 2 {
+ if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
t.Fatalf("expected 2 bytes, got %v", n)
}
close(clientDone)
@@ -565,7 +579,7 @@ func TestConnectRace(t *testing.T) {
}()
for i := 0; i < 1000; i++ {
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -602,7 +616,7 @@ func TestMessageReadMode(t *testing.T) {
s.Close()
}()
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -643,7 +657,7 @@ func TestListenConnectRace(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := winpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
c.Close()
}