summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go21
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go43
2 files changed, 63 insertions, 1 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index df8bf435d..2153222cf 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -16,6 +16,7 @@
package gonet
import (
+ "context"
"errors"
"io"
"net"
@@ -495,6 +496,12 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
// DialTCP creates a new TCP Conn connected to the specified address.
func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+ return DialContextTCP(context.Background(), s, addr, network)
+}
+
+// DialContextTCP creates a new TCP Conn connected to the specified address
+// with the option of adding cancellation and timeouts.
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -509,9 +516,21 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc
wq.EventRegister(&waitEntry, waiter.EventOut)
defer wq.EventUnregister(&waitEntry)
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
err = ep.Connect(addr)
if err == tcpip.ErrConnectStarted {
- <-notifyCh
+ select {
+ case <-ctx.Done():
+ ep.Close()
+ return nil, ctx.Err()
+ case <-notifyCh:
+ }
+
err = ep.GetSockOpt(tcpip.ErrorOption{})
}
if err != nil {
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 2c81c5697..2552004a9 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -15,6 +15,7 @@
package gonet
import (
+ "context"
"fmt"
"io"
"net"
@@ -595,6 +596,48 @@ func TestTCPDialError(t *testing.T) {
}
}
+func TestDialContextTCPCanceled(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ cancel()
+
+ if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled {
+ t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled)
+ }
+}
+
+func TestDialContextTCPTimeout(t *testing.T) {
+ s, err := newLoopbackStack()
+ if err != nil {
+ t.Fatalf("newLoopbackStack() = %v", err)
+ }
+
+ addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+ fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+ time.Sleep(time.Second)
+ r.Complete(true)
+ })
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+ ctx := context.Background()
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond))
+ defer cancel()
+
+ if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded {
+ t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded)
+ }
+}
+
func TestNetTest(t *testing.T) {
nettest.TestConn(t, makePipe)
}