diff options
Diffstat (limited to 'pkg/tcpip/adapters/gonet')
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet_test.go | 18 |
2 files changed, 21 insertions, 19 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 7c7495c30..c188aaa18 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -248,7 +248,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { func (l *TCPListener) Accept() (net.Conn, error) { n, wq, err := l.ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) l.wq.EventRegister(&waitEntry, waiter.EventIn) @@ -257,7 +257,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { for { n, wq, err = l.ep.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } @@ -298,14 +298,14 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} res, err := ep.Read(&w, opts) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { res, err = ep.Read(&w, opts) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } select { @@ -316,7 +316,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s } } - if err == tcpip.ErrClosedForReceive { + if _, ok := err.(*tcpip.ErrClosedForReceive); ok { return 0, io.EOF } @@ -356,7 +356,7 @@ func (c *TCPConn) Write(b []byte) (int, error) { } // We must handle two soft failure conditions simultaneously: - // 1. Write may write nothing and return tcpip.ErrWouldBlock. + // 1. Write may write nothing and return *tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have // not already and wait to try again. // 2. Write may write fewer than the full number of bytes and return @@ -376,9 +376,9 @@ func (c *TCPConn) Write(b []byte) (int, error) { r.Reset(b[nbytes:]) n, err := c.ep.Write(&r, tcpip.WriteOptions{}) nbytes += int(n) - switch err { + switch err.(type) { case nil: - case tcpip.ErrWouldBlock: + case *tcpip.ErrWouldBlock: if ch == nil { entry, ch = waiter.NewChannelEntry(nil) @@ -495,7 +495,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, } err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { select { case <-ctx.Done(): ep.Close() @@ -649,7 +649,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { var r bytes.Reader r.Reset(b) n, err := c.ep.Write(&r, writeOptions) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.wq.EventRegister(&waitEntry, waiter.EventOut) @@ -662,7 +662,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } n, err = c.ep.Write(&r, writeOptions) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index b196324c7..2b3ea4bdf 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -58,7 +58,7 @@ func TestTimeouts(t *testing.T) { } } -func newLoopbackStack() (*stack.Stack, *tcpip.Error) { +func newLoopbackStack() (*stack.Stack, tcpip.Error) { // Create the stack and add a NIC. s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, @@ -94,7 +94,7 @@ type testConnection struct { ep tcpip.Endpoint } -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { +func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -105,7 +105,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er wq.EventRegister(&entry, waiter.EventOut) err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { <-ch err = ep.LastError() } @@ -660,11 +660,13 @@ func TestTCPDialError(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - _, err := DialTCP(s, addr, ipv4.ProtocolNumber) - got, ok := err.(*net.OpError) - want := tcpip.ErrNoRoute - if !ok || got.Err.Error() != want.String() { - t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) + switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) { + case *net.OpError: + if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() { + t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{}) + } + default: + t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{}) } } |