diff options
Diffstat (limited to 'ipc/winpipe/winpipe_test.go')
-rw-r--r-- | ipc/winpipe/winpipe_test.go | 68 |
1 files changed, 41 insertions, 27 deletions
diff --git a/ipc/winpipe/winpipe_test.go b/ipc/winpipe/winpipe_test.go index ea515e3..7cde94e 100644 --- a/ipc/winpipe/winpipe_test.go +++ b/ipc/winpipe/winpipe_test.go @@ -1,11 +1,11 @@ -//go:build windows - /* SPDX-License-Identifier: MIT * - * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2015 Microsoft */ +//go:build windows + package winpipe_test import ( @@ -22,7 +22,7 @@ import ( "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.org/x/sys/windows/winpipe" ) func randomPipePath() string { @@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) { t.Fatalf("unable to write pong to pipe: %v", err) } }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatalf("unable to dial pipe: %v", err) } defer client.Close() + client.SetDeadline(time.Now().Add(time.Second * 5)) var data [1]byte data[0] = ping _, err = client.Write(data[:]) @@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) { } func TestDialUnknownFailsImmediately(t *testing.T) { - _, err := winpipe.Dial(randomPipePath(), nil, nil) + _, err := winpipe.DialTimeout(randomPipePath(), time.Duration(0)) if !errors.Is(err, syscall.ENOENT) { t.Fatalf("expected ENOENT got %v", err) } @@ -98,8 +99,10 @@ func TestDialListenerTimesOut(t *testing.T) { t.Fatal(err) } defer l.Close() - d := 10 * time.Millisecond - _, err = winpipe.Dial(pipePath, &d, nil) + pipe, err := winpipe.DialTimeout(pipePath, 10*time.Millisecond) + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -114,7 +117,10 @@ func TestDialContextListenerTimesOut(t *testing.T) { defer l.Close() d := 10 * time.Millisecond ctx, _ := context.WithTimeout(context.Background(), d) - _, err = winpipe.DialContext(ctx, pipePath, nil) + pipe, err := winpipe.DialContext(ctx, pipePath) + if err == nil { + pipe.Close() + } if err != context.DeadlineExceeded { t.Fatalf("expected context.DeadlineExceeded, got %v", err) } @@ -127,10 +133,10 @@ func TestDialListenerGetsCancelled(t *testing.T) { if err != nil { t.Fatal(err) } - ch := make(chan error) defer l.Close() + ch := make(chan error) go func(ctx context.Context, ch chan error) { - _, err := winpipe.DialContext(ctx, pipePath, nil) + _, err := winpipe.DialContext(ctx, pipePath) ch <- err }(ctx, ch) time.Sleep(time.Millisecond * 30) @@ -155,7 +161,10 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := winpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + pipe.Close() + } if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) } @@ -179,7 +188,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, ch <- response{c, err} }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { return } @@ -245,7 +254,7 @@ func TestFullListenDialReadWrite(t *testing.T) { ch := make(chan int) go server(l, ch) - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -374,7 +383,10 @@ func TestDialTimesOutByDefault(t *testing.T) { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := winpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -400,7 +412,7 @@ func TestTimeoutPendingRead(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -448,7 +460,7 @@ func TestTimeoutPendingWrite(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -496,19 +508,21 @@ func TestEchoWithMessaging(t *testing.T) { clientDone := make(chan bool) go func() { // server echo - conn, e := l.Accept() - if e != nil { - t.Fatal(e) + conn, err := l.Accept() + if err != nil { + t.Fatal(err) } defer conn.Close() time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent - io.Copy(conn, conn) + _, err = io.Copy(conn, conn) + if err != nil { + t.Fatal(err) + } conn.(CloseWriter).CloseWrite() close(listenerDone) }() - timeout := 1 * time.Second - client, err := winpipe.Dial(pipePath, &timeout, nil) + client, err := winpipe.DialTimeout(pipePath, time.Second) if err != nil { t.Fatal(err) } @@ -521,7 +535,7 @@ func TestEchoWithMessaging(t *testing.T) { if e != nil { t.Fatal(e) } - if n != 2 { + if n != 2 || bytes[0] != 0 || bytes[1] != 1 { t.Fatalf("expected 2 bytes, got %v", n) } close(clientDone) @@ -565,7 +579,7 @@ func TestConnectRace(t *testing.T) { }() for i := 0; i < 1000; i++ { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -602,7 +616,7 @@ func TestMessageReadMode(t *testing.T) { s.Close() }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -643,7 +657,7 @@ func TestListenConnectRace(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := winpipe.DialTimeout(pipePath, time.Duration(0)) if err == nil { c.Close() } |