diff options
author | Ian Gudger <igudger@google.com> | 2019-05-07 14:26:24 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-05-07 14:27:36 -0700 |
commit | 20862f0db27efac0eed3bb23d01b22b09bddfa27 (patch) | |
tree | 2bb9450cb50fc0e307826c7be94de2571daac940 /pkg/tcpip/adapters/gonet/gonet.go | |
parent | e5432fa1b365edcebf9c8c01e2c40ade3014f282 (diff) |
Add gonet.DialContextTCP.
Allows cancellation and timeouts.
PiperOrigin-RevId: 247090428
Change-Id: I91907f12e218677dcd0e0b6d72819deedbd9f20c
Diffstat (limited to 'pkg/tcpip/adapters/gonet/gonet.go')
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index df8bf435d..2153222cf 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "context" "errors" "io" "net" @@ -495,6 +496,12 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { // DialTCP creates a new TCP Conn connected to the specified address. func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { + return DialContextTCP(context.Background(), s, addr, network) +} + +// DialContextTCP creates a new TCP Conn connected to the specified address +// with the option of adding cancellation and timeouts. +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -509,9 +516,21 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc wq.EventRegister(&waitEntry, waiter.EventOut) defer wq.EventUnregister(&waitEntry) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + err = ep.Connect(addr) if err == tcpip.ErrConnectStarted { - <-notifyCh + select { + case <-ctx.Done(): + ep.Close() + return nil, ctx.Err() + case <-notifyCh: + } + err = ep.GetSockOpt(tcpip.ErrorOption{}) } if err != nil { |