diff options
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet_test.go | 43 |
2 files changed, 63 insertions, 1 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index df8bf435d..2153222cf 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "context" "errors" "io" "net" @@ -495,6 +496,12 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { // DialTCP creates a new TCP Conn connected to the specified address. func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { + return DialContextTCP(context.Background(), s, addr, network) +} + +// DialContextTCP creates a new TCP Conn connected to the specified address +// with the option of adding cancellation and timeouts. +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -509,9 +516,21 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc wq.EventRegister(&waitEntry, waiter.EventOut) defer wq.EventUnregister(&waitEntry) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + err = ep.Connect(addr) if err == tcpip.ErrConnectStarted { - <-notifyCh + select { + case <-ctx.Done(): + ep.Close() + return nil, ctx.Err() + case <-notifyCh: + } + err = ep.GetSockOpt(tcpip.ErrorOption{}) } if err != nil { diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 2c81c5697..2552004a9 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -15,6 +15,7 @@ package gonet import ( + "context" "fmt" "io" "net" @@ -595,6 +596,48 @@ func TestTCPDialError(t *testing.T) { } } +func TestDialContextTCPCanceled(t *testing.T) { + s, err := newLoopbackStack() + if err != nil { + t.Fatalf("newLoopbackStack() = %v", err) + } + + addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + cancel() + + if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled { + t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled) + } +} + +func TestDialContextTCPTimeout(t *testing.T) { + s, err := newLoopbackStack() + if err != nil { + t.Fatalf("newLoopbackStack() = %v", err) + } + + addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + + fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { + time.Sleep(time.Second) + r.Complete(true) + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) + + ctx := context.Background() + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) + defer cancel() + + if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded { + t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded) + } +} + func TestNetTest(t *testing.T) { nettest.TestConn(t, makePipe) } |