summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/adapters/gonet
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/adapters/gonet')
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go22
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go18
2 files changed, 21 insertions, 19 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 7c7495c30..c188aaa18 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -248,7 +248,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
l.wq.EventRegister(&waitEntry, waiter.EventIn)
@@ -257,7 +257,7 @@ func (l *TCPListener) Accept() (net.Conn, error) {
for {
n, wq, err = l.ep.Accept(nil)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
@@ -298,14 +298,14 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
res, err := ep.Read(&w, opts)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
res, err = ep.Read(&w, opts)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
select {
@@ -316,7 +316,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
}
}
- if err == tcpip.ErrClosedForReceive {
+ if _, ok := err.(*tcpip.ErrClosedForReceive); ok {
return 0, io.EOF
}
@@ -356,7 +356,7 @@ func (c *TCPConn) Write(b []byte) (int, error) {
}
// We must handle two soft failure conditions simultaneously:
- // 1. Write may write nothing and return tcpip.ErrWouldBlock.
+ // 1. Write may write nothing and return *tcpip.ErrWouldBlock.
// If this happens, we need to register for notifications if we have
// not already and wait to try again.
// 2. Write may write fewer than the full number of bytes and return
@@ -376,9 +376,9 @@ func (c *TCPConn) Write(b []byte) (int, error) {
r.Reset(b[nbytes:])
n, err := c.ep.Write(&r, tcpip.WriteOptions{})
nbytes += int(n)
- switch err {
+ switch err.(type) {
case nil:
- case tcpip.ErrWouldBlock:
+ case *tcpip.ErrWouldBlock:
if ch == nil {
entry, ch = waiter.NewChannelEntry(nil)
@@ -495,7 +495,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); ok {
select {
case <-ctx.Done():
ep.Close()
@@ -649,7 +649,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
var r bytes.Reader
r.Reset(b)
n, err := c.ep.Write(&r, writeOptions)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.wq.EventRegister(&waitEntry, waiter.EventOut)
@@ -662,7 +662,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
n, err = c.ep.Write(&r, writeOptions)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index b196324c7..2b3ea4bdf 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -58,7 +58,7 @@ func TestTimeouts(t *testing.T) {
}
}
-func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
+func newLoopbackStack() (*stack.Stack, tcpip.Error) {
// Create the stack and add a NIC.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
@@ -94,7 +94,7 @@ type testConnection struct {
ep tcpip.Endpoint
}
-func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
+func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) {
wq := &waiter.Queue{}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
@@ -105,7 +105,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er
wq.EventRegister(&entry, waiter.EventOut)
err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); ok {
<-ch
err = ep.LastError()
}
@@ -660,11 +660,13 @@ func TestTCPDialError(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- _, err := DialTCP(s, addr, ipv4.ProtocolNumber)
- got, ok := err.(*net.OpError)
- want := tcpip.ErrNoRoute
- if !ok || got.Err.Error() != want.String() {
- t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
+ switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) {
+ case *net.OpError:
+ if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() {
+ t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{})
+ }
+ default:
+ t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{})
}
}