diff options
Diffstat (limited to 'pkg/tcpip/transport')
30 files changed, 1280 insertions, 1082 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 256e19296..3cf05520d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -69,8 +69,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int + mu sync.RWMutex `state:"nosave"` // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState @@ -85,7 +84,7 @@ type endpoint struct { ops tcpip.SocketOptions } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -94,11 +93,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.SetSendBufferSize(32*1024, false /* notify */) + + // Override with stack defaults. + var ss tcpip.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) + } return ep, nil } @@ -119,7 +124,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -154,14 +159,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -188,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -199,7 +204,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.state { case stateInitial: case stateConnected: @@ -207,11 +212,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case stateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -236,18 +241,18 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -257,10 +262,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -270,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. @@ -292,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -313,11 +318,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc route = r } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } + var err tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) @@ -334,12 +340,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.TTLOption: e.mu.Lock() @@ -351,7 +357,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -362,11 +368,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSize - e.mu.Unlock() - return v, nil case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() @@ -381,18 +382,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -410,7 +411,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi // Linux performs these basic checks. if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } icmpv4.SetChecksum(0) @@ -424,9 +425,9 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -441,7 +442,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } dataVV := data.ToVectorisedView() @@ -456,7 +457,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { return tcpip.FullAddress{}, 0, err @@ -465,12 +466,12 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -485,12 +486,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -535,19 +536,19 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() e.shutdownFlags |= flags if e.state != stateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } if flags&tcpip.ShutdownRead != 0 { @@ -565,31 +566,31 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) - switch err { + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + switch err.(type) { case nil: return true, nil - case tcpip.ErrPortInUse: + case *tcpip.ErrPortInUse: return false, nil default: return false, err @@ -599,11 +600,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ return id, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. if e.state != stateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -619,7 +620,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { if len(addr.Addr) != 0 { // A local address was specified, verify that it's valid. if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -647,7 +648,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -663,7 +664,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -675,12 +676,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.state != stateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -805,7 +806,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. -func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 9d263c0ec..c9fa9974a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -69,12 +69,13 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) if e.state != stateBound && e.state != stateConnected { return } - var err *tcpip.Error + var err tcpip.Error if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { @@ -84,7 +85,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.ID.LocalAddress = e.route.LocalAddress } else if len(e.ID.LocalAddress) != 0 { // stateBound if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3820e5dc7..47f7dd1cb 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -59,18 +59,18 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } @@ -87,7 +87,7 @@ func (p *protocol) MinimumPacketSize() int { } // ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. -func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { switch p.number { case ProtocolNumber4: hdr := header.ICMPv4(v) @@ -106,13 +106,13 @@ func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stac } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c0d6fb442..73bb66830 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -79,24 +79,22 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool - boundNIC tcpip.NICID + mu sync.RWMutex `state:"nosave"` + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID // lastErrorMu protects lastError. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // ops is used to get socket level options. ops tcpip.SocketOptions } // NewEndpoint returns a new packet endpoint. -func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -106,14 +104,13 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb netProto: netProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - ep.sndBufSizeMax = ss.Default + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -162,16 +159,16 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. if ep.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if ep.rcvClosed { ep.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } ep.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -201,49 +198,49 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul n, err := packet.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be // disconnected, and this function always returns tpcip.ErrNotSupported. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be -// connected, and this function always returnes tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return tcpip.ErrNotSupported +// connected, and this function always returnes *tcpip.ErrNotSupported. +func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used -// with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - return tcpip.ErrNotSupported +// with Shutdown, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with -// Listen, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +// Listen, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with -// Accept, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +// Accept, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { // TODO(gvisor.dev/issue/173): Add Bind support. // "By default, all packets of the specified protocol type are passed @@ -277,14 +274,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -306,38 +303,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be // used with SetSockOpt, and this function always returns -// tcpip.ErrNotSupported. -func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +// *tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := ep.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - ep.mu.Lock() - ep.sndBufSizeMax = v - ep.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -357,11 +336,11 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (ep *endpoint) LastError() *tcpip.Error { +func (ep *endpoint) LastError() tcpip.Error { ep.lastErrorMu.Lock() defer ep.lastErrorMu.Unlock() @@ -371,19 +350,19 @@ func (ep *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (ep *endpoint) UpdateLastError(err *tcpip.Error) { +func (ep *endpoint) UpdateLastError(err tcpip.Error) { ep.lastErrorMu.Lock() ep.lastError = err ep.lastErrorMu.Unlock() } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrNotSupported{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -395,12 +374,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - ep.mu.Lock() - v := ep.sndBufSizeMax - ep.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: ep.rcvMu.Lock() v := ep.rcvBufSizeMax @@ -408,7 +381,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index e2fa96d17..ece662c0d 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -63,29 +63,11 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. ep.stack = stack.StackFromEnv + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(*err) + panic(err) } } - -// saveLastError is invoked by stateify. -func (ep *endpoint) saveLastError() string { - if ep.lastError == nil { - return "" - } - - return ep.lastError.String() -} - -// loadLastError is invoked by stateify. -func (ep *endpoint) loadLastError(s string) { - if s == "" { - return - } - - ep.lastError = tcpip.StringToError(s) -} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ae743f75e..9c9ccc0ff 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -76,12 +76,10 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - connected bool - bound bool + mu sync.RWMutex `state:"nosave"` + closed bool + connected bool + bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` @@ -95,13 +93,13 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } e := &endpoint{ @@ -112,16 +110,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, associated: associated, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetHeaderIncluded(!associated) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -138,7 +136,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return nil, err } @@ -159,7 +157,7 @@ func (e *endpoint) Close() { return } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.rcvMu.Lock() defer e.rcvMu.Unlock() @@ -191,16 +189,16 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -227,37 +225,37 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } // Write implements tcpip.Endpoint.Write. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // We can create, but not write to, unassociated IPv6 endpoints. if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } } n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -267,22 +265,22 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } e.mu.RLock() defer e.mu.RUnlock() if e.closed { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } - payloadBytes, err := p.FullPayload() - if err != nil { - return 0, err + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} } // If this is an unassociated socket and callee provided a nonzero @@ -290,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } dstAddr := ip.DestinationAddress() // Update dstAddr with the address in the IP header, unless @@ -311,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - return 0, tcpip.ErrDestinationRequired + return 0, &tcpip.ErrDestinationRequired{} } return e.finishWrite(payloadBytes, e.route) @@ -321,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } // Find the route to the destination. If BindAddress is 0, @@ -338,7 +336,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), @@ -365,22 +363,22 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { - return tcpip.ErrAddressFamilyNotSupported + return &tcpip.ErrAddressFamilyNotSupported{} } e.mu.Lock() defer e.mu.Unlock() if e.closed { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nic := addr.NIC @@ -395,7 +393,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } } @@ -407,15 +405,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { route.Release() return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = nic } - // Save the route we've connected via. + if e.route != nil { + // If the endpoint was previously connected then release any previous route. + e.route.Release() + } e.route = route e.connected = true @@ -423,42 +424,42 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() if !e.connected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { + return &tcpip.ErrBadLocalAddress{} } if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = addr.NIC e.BindNICID = addr.NIC } @@ -470,14 +471,14 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -498,37 +499,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -548,17 +531,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -570,12 +553,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax @@ -583,7 +560,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } @@ -703,7 +680,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. -func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 4a7e1c039..263ec5146 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -69,10 +69,11 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) // If the endpoint is connected, re-connect. if e.connected { - var err *tcpip.Error + var err tcpip.Error // TODO(gvisor.dev/issue/4906): Properly restore the route with the right // remote address. We used to pass e.remote.RemoteAddress which was // effectively the empty address but since moving e.route to hold a pointer @@ -88,12 +89,12 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if e.bound { if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } if e.associated { - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index f30aa2a4a..e393b993d 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -25,11 +25,11 @@ import ( type EndpointFactory struct{} // NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. -func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } // NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. -func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7e81203ba..fcdd032c5 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -99,7 +99,6 @@ go_test( "//pkg/rand", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6921de0f1..842c1622b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,7 +199,7 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // On success, a handshake h is returned with h.ep.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -267,7 +267,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.mu.Unlock() ep.Close() - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } l.addPendingEndpoint(ep) @@ -281,14 +281,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q l.removePendingEndpoint(ep) - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } deferAccept = l.listenEP.deferAccept } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { ep.mu.Unlock() ep.Close() @@ -313,7 +313,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // established endpoint is returned with e.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) { h, err := l.startHandshake(s, opts, queue, owner) if err != nil { return nil, err @@ -467,7 +467,7 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error { defer s.decRef() h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) @@ -522,7 +522,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // and needs to handle it. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -692,7 +692,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { n.mu.Unlock() n.Close() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index f711cd4df..4695b66d6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -226,7 +226,7 @@ func (h *handshake) checkAck(s *segment) bool { // synSentState handles a segment received when the TCP 3-way handshake is in // the SYN-SENT state. -func (h *handshake) synSentState(s *segment) *tcpip.Error { +func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { @@ -237,7 +237,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.ep.workerCleanup = true // Although the RFC above calls out ECONNRESET, Linux actually returns // ECONNREFUSED here so we do as well. - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -314,12 +314,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. -func (h *handshake) synRcvdState(s *segment) *tcpip.Error { +func (h *handshake) synRcvdState(s *segment) tcpip.Error { if s.flagIsSet(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -349,7 +349,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) if !h.active { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } h.resetState() @@ -412,7 +412,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } -func (h *handshake) handleSegment(s *segment) *tcpip.Error { +func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) @@ -429,7 +429,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error { // processSegments goes through the segment queue and processes up to // maxSegmentsPerWake (if they're available). -func (h *handshake) processSegments() *tcpip.Error { +func (h *handshake) processSegments() tcpip.Error { for i := 0; i < maxSegmentsPerWake; i++ { s := h.ep.segmentQueue.dequeue() if s == nil { @@ -505,7 +505,7 @@ func (h *handshake) start() { } // complete completes the TCP 3-way handshake initiated by h.start(). -func (h *handshake) complete() *tcpip.Error { +func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper resendWaker := sleep.Waker{} @@ -555,7 +555,7 @@ func (h *handshake) complete() *tcpip.Error { case wakerForNotification: n := h.ep.fetchNotifications() if (n¬ifyClose)|(n¬ifyAbort) != 0 { - return tcpip.ErrAborted + return &tcpip.ErrAborted{} } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { @@ -593,19 +593,19 @@ type backoffTimer struct { t *time.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { - return nil, tcpip.ErrTimeout + return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} bt.t = time.AfterFunc(timeout, f) return bt, nil } -func (bt *backoffTimer) reset() *tcpip.Error { +func (bt *backoffTimer) reset() tcpip.Error { bt.timeout *= 2 if bt.timeout > MaxRTO { - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } bt.t.Reset(bt.timeout) return nil @@ -706,7 +706,7 @@ type tcpFields struct { txHash uint32 } -func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { @@ -716,7 +716,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp return nil } -func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error { tf.txHash = e.txHash if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() @@ -755,7 +755,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { // We need to shallow clone the VectorisedView here as ReadToView will // split the VectorisedView and Trim underlying views as it splits. Not // doing the clone here will cause the underlying views of data itself @@ -803,7 +803,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > math.MaxUint16 { tf.rcvWnd = math.MaxUint16 @@ -875,7 +875,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. -func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] @@ -895,55 +895,60 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn return err } -func (e *endpoint) handleWrite() *tcpip.Error { - // Move packets from send queue to send list. The queue is accessible - // from other goroutines and protected by the send mutex, while the send - // list is only accessible from the handler goroutine, so it needs no - // mutexes. +func (e *endpoint) handleWrite() { e.sndBufMu.Lock() + next := e.drainSendQueueLocked() + e.sndBufMu.Unlock() + + e.sendData(next) +} +// Move packets from send queue to send list. +// +// Precondition: e.sndBufMu must be locked. +func (e *endpoint) drainSendQueueLocked() *segment { first := e.sndQueue.Front() if first != nil { e.snd.writeList.PushBackList(&e.sndQueue) e.sndBufInQueue = 0 } + return first +} - e.sndBufMu.Unlock() - +// Precondition: e.mu must be locked. +func (e *endpoint) sendData(next *segment) { // Initialize the next segment to write if it's currently nil. if e.snd.writeNext == nil { - e.snd.writeNext = first + e.snd.writeNext = next } // Push out any new packets. e.snd.sendData() - - return nil } -func (e *endpoint) handleClose() *tcpip.Error { +func (e *endpoint) handleClose() { if !e.EndpointState().connected() { - return nil + return } // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - - return nil } // resetConnectionLocked puts the endpoint in an error state with the given // error code and sends a RST if and only if the error is not ErrConnectionReset // indicating that the connection is being reset due to receiving a RST. This // method must only be called from the protocol goroutine. -func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { +func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) e.hardError = err - if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + switch err.(type) { + case *tcpip.ErrConnectionReset, *tcpip.ErrTimeout: + default: // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -1053,7 +1058,7 @@ func (e *endpoint) drainClosingSegmentQueue() { } } -func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { +func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states // except SYN-SENT, all reset (RST) segments are @@ -1081,7 +1086,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. case StateCloseWait: e.transitionToStateCloseLocked() - e.hardError = tcpip.ErrAborted + e.hardError = &tcpip.ErrAborted{} e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1094,14 +1099,14 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // handleSegment is invoked from the processor goroutine // rather than the worker goroutine. e.notifyProtocolGoroutine(notifyResetByPeer) - return false, tcpip.ErrConnectionReset + return false, &tcpip.ErrConnectionReset{} } } return true, nil } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { +func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1148,7 +1153,7 @@ func (e *endpoint) probeSegment() { // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { +func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. defer e.probeSegment() @@ -1222,7 +1227,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. -func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { +func (e *endpoint) keepaliveTimerExpired() tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() @@ -1236,13 +1241,13 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with @@ -1286,7 +1291,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1332,6 +1337,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } + // Reaching this point means that we successfully completed the 3-way + // handshake with our peer. + // + // Completing the 3-way handshake is an indication that the route is valid + // and the remote is reachable as the only way we can complete a handshake + // is if our SYN reached the remote and their ACK reached us. + e.route.ConfirmReachable() + drained := e.drainDone != nil if drained { close(e.drainDone) @@ -1344,19 +1357,25 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // wakes up. funcs := []struct { w *sleep.Waker - f func() *tcpip.Error + f func() tcpip.Error }{ { w: &e.sndWaker, - f: e.handleWrite, + f: func() tcpip.Error { + e.handleWrite() + return nil + }, }, { w: &e.sndCloseWaker, - f: e.handleClose, + f: func() tcpip.Error { + e.handleClose() + return nil + }, }, { w: &closeWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { // This means the socket is being closed due // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. @@ -1367,10 +1386,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.snd.resendWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { if !e.snd.retransmitTimerExpired() { e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } return nil }, @@ -1381,7 +1400,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.newSegmentWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { return e.handleSegments(false /* fastPath */) }, }, @@ -1391,7 +1410,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.notificationWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { n := e.fetchNotifications() if n¬ifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -1408,11 +1427,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n¬ifyReset != 0 || n¬ifyAbort != 0 { - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} } if n¬ifyResetByPeer != 0 { - return tcpip.ErrConnectionReset + return &tcpip.ErrConnectionReset{} } if n¬ifyClose != 0 && closeTimer == nil { @@ -1491,7 +1510,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. - cleanupOnError := func(err *tcpip.Error) { + cleanupOnError := func(err tcpip.Error) { e.stack.Stats().TCP.CurrentConnected.Decrement() e.workerCleanup = true if err != nil { diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 1d1b01a6c..2d90246e4 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -15,11 +15,11 @@ package tcp_test import ( + "strings" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -37,7 +37,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrNoRoute { + if _, ok := err.(*tcpip.ErrNoRoute); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } } @@ -49,7 +49,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -156,7 +156,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -391,7 +391,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) { t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) } + var r strings.Reader data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + r.Reset(data) + nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() tcp = header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -523,7 +525,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -547,7 +549,7 @@ func TestV4AcceptOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) @@ -611,7 +613,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -633,7 +635,7 @@ func TestV4ListenCloseOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ea509ac73..6e4e26c39 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -386,12 +386,12 @@ type endpoint struct { // hardError is meaningful only when state is stateError. It stores the // error to be returned when read/write syscalls are called and the // endpoint is in this state. hardError is protected by endpoint mu. - hardError *tcpip.Error `state:".(string)"` + hardError tcpip.Error // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // rcvReadMu synchronizes calls to Read. // @@ -557,7 +557,6 @@ type endpoint struct { // When the send side is closed, the protocol goroutine is notified via // sndCloseWaker, and sndClosed is set to true. sndBufMu sync.Mutex `state:"nosave"` - sndBufSize int sndBufUsed int sndClosed bool sndBufInQueue seqnum.Size @@ -869,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, - sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. @@ -882,13 +880,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) + e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - e.sndBufSize = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -967,7 +966,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() - if e.sndClosed || e.sndBufUsed < e.sndBufSize { + sndBufSize := e.getSendBufferSize() + if e.sndClosed || e.sndBufUsed < sndBufSize { result |= waiter.EventOut } e.sndBufMu.Unlock() @@ -1059,7 +1059,7 @@ func (e *endpoint) Close() { if isResetState { // Close the endpoint without doing full shutdown and // send a RST. - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.closeNoShutdownLocked() // Wake up worker to close the endpoint. @@ -1087,7 +1087,7 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1161,7 +1161,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1293,14 +1293,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) hardErrorLocked() *tcpip.Error { +func (e *endpoint) hardErrorLocked() tcpip.Error { err := e.hardError e.hardError = nil return err } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) lastErrorLocked() *tcpip.Error { +func (e *endpoint) lastErrorLocked() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1309,7 +1309,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error { } // LastError implements tcpip.Endpoint.LastError. -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.LockUser() defer e.UnlockUser() if err := e.hardErrorLocked(); err != nil { @@ -1319,7 +1319,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() e.lastErrorMu.Lock() e.lastError = err @@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvReadMu.Lock() defer e.rcvReadMu.Unlock() @@ -1337,7 +1337,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { - if serr == tcpip.ErrClosedForReceive { + if _, ok := serr.(*tcpip.ErrClosedForReceive); ok { e.stats.ReadErrors.ReadClosed.Increment() } return tcpip.ReadResult{}, serr @@ -1377,7 +1377,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // If something is read, we must report it. Report error when nothing is read. if done == 0 && err != nil { - return tcpip.ReadResult{}, tcpip.ErrBadBuffer + return tcpip.ReadResult{}, &tcpip.ErrBadBuffer{} } return tcpip.ReadResult{ Count: done, @@ -1389,7 +1389,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // inclusive range of segments that can be read. // // Precondition: e.rcvReadMu must be held. -func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { +func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1398,7 +1398,7 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { // on a receive. It can expect to read any data after the handshake // is complete. RFC793, section 3.9, p58. if e.EndpointState() == StateSynSent { - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } // The endpoint can be read if it's connected, or if it's already closed @@ -1414,17 +1414,17 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { if err := e.hardErrorLocked(); err != nil { return nil, nil, err } - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } e.stats.ReadErrors.NotConnected.Increment() - return nil, nil, tcpip.ErrNotConnected + return nil, nil, &tcpip.ErrNotConnected{} } if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } return e.rcvList.Front(), e.rcvList.Back(), nil @@ -1476,106 +1476,117 @@ func (e *endpoint) commitRead(done int) *segment { // moment. If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu -func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { +func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: if err := e.hardErrorLocked(); err != nil { return 0, err } - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case !s.connecting() && !s.connected(): - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case s.connecting(): // As per RFC793, page 56, a send request arriving when in connecting // state, can be queued to be completed after the state becomes // connected. Return an error code for the caller of endpoint Write to // try again, until the connection handshake is complete. - return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } // Check if the connection has already been closed for sends. if e.sndClosed { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } - avail := e.sndBufSize - e.sndBufUsed + sndBufSize := e.getSendBufferSize() + avail := sndBufSize - e.sndBufUsed if avail <= 0 { - return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } return avail, nil } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. e.LockUser() - e.sndBufMu.Lock() - - avail, err := e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, err - } - - // We can release locks while copying data. - // - // This is not possible if atomic is set, because we can't allow the - // available buffer space to be consumed by some other caller while we - // are copying data in. - if !opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - - // Fetch data. - v, perr := p.Payload(avail) - if perr != nil || len(v) == 0 { - // Note that perr may be nil if len(v) == 0. - if opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - return 0, perr - } + defer e.UnlockUser() - if !opts.Atomic { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - e.LockUser() + nextSeg, n, err := func() (*segment, int, tcpip.Error) { e.sndBufMu.Lock() + defer e.sndBufMu.Unlock() + avail, err := e.isEndpointWritableLocked() if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, err + return nil, 0, err } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + v, err := func() ([]byte, tcpip.Error) { + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + defer e.sndBufMu.Lock() + + e.UnlockUser() + defer e.LockUser() + } + + // Fetch data. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return nil, nil + } + v := make([]byte, avail) + if _, err := io.ReadFull(p, v); err != nil { + return nil, &tcpip.ErrBadBuffer{} + } + return v, nil + }() + if len(v) == 0 || err != nil { + return nil, 0, err + } + + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.stats.WriteErrors.WriteClosed.Increment() + return nil, 0, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } - } - // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() + // Add data to the send queue. + s := newOutgoingSegment(e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) - // Do the work inline. - e.handleWrite() - e.UnlockUser() - return int64(len(v)), nil + return e.drainSendQueueLocked(), len(v), nil + }() + if err != nil { + return 0, err + } + e.sendData(nextSeg) + return int64(n), nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1682,8 +1693,16 @@ func (e *endpoint) OnCorkOptionSet(v bool) { } } +func (e *endpoint) getSendBufferSize() int { + sndBufSize, err := e.ops.GetSendBufferSize() + if err != nil { + panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err)) + } + return int(sndBufSize) +} + // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 const inetECNMask = 3 @@ -1711,7 +1730,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.MaxSegOption: userMSS := v if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.userMSS = uint16(userMSS) @@ -1722,7 +1741,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Return not supported if attempting to set this option to // anything other than path MTU discovery disabled. if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.ReceiveBufferSizeOption: @@ -1775,31 +1794,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) - } - - if v > ss.Max { - v = ss.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < ss.Min { - v = ss.Min - } - } else { - v = math.MaxInt32 - } - - e.sndBufMu.Lock() - e.sndBufSize = v - e.sndBufMu.Unlock() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1807,7 +1801,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.TCPSynCountOption: if v < 1 || v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.maxSynRetries = uint8(v) @@ -1823,7 +1817,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: e.UnlockUser() - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -1844,7 +1838,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -1890,7 +1884,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. - return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPLingerTimeoutOption: e.LockUser() @@ -1933,13 +1927,13 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // readyReceiveSize returns the number of bytes ready to be received. -func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { +func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { e.LockUser() defer e.UnlockUser() // The endpoint cannot be in listen state. if e.EndpointState() == StateListen { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } e.rcvListMu.Lock() @@ -1949,7 +1943,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.KeepaliveCountOption: e.keepalive.Lock() @@ -1985,12 +1979,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - v := e.sndBufSize - e.sndBufMu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvListMu.Lock() v := e.rcvBufSize @@ -2019,24 +2007,38 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return 1, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} + } +} + +func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { + info := tcpip.TCPInfoOption{} + e.LockUser() + snd := e.snd + if snd != nil { + // We do not calculate RTT before sending the data packets. If + // the connection did not send and receive data, then RTT will + // be zero. + snd.rtt.Lock() + info.RTT = snd.rtt.srtt + info.RTTVar = snd.rtt.rttvar + snd.rtt.Unlock() + + info.RTO = snd.rto + info.CcState = snd.state + info.SndSsthresh = uint32(snd.sndSsthresh) + info.SndCwnd = uint32(snd.sndCwnd) + info.ReorderSeen = snd.rc.reorderSeen } + e.UnlockUser() + return info } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.TCPInfoOption: - *o = tcpip.TCPInfoOption{} - e.LockUser() - snd := e.snd - e.UnlockUser() - if snd != nil { - snd.rtt.Lock() - o.RTT = snd.rtt.srtt - o.RTTVar = snd.rtt.rttvar - snd.rtt.Unlock() - } + *o = e.getTCPInfo() case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -2082,14 +2084,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -2098,18 +2100,20 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { err := e.connect(addr, true, true) - if err != nil && !err.IgnoreStats() { - // Connect failed. Let's wake up any waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } @@ -2120,7 +2124,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2139,7 +2143,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return nil } // Otherwise return that it's already connected. - return tcpip.ErrAlreadyConnected + return &tcpip.ErrAlreadyConnected{} } nicID := addr.NIC @@ -2152,7 +2156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if nicID != 0 && nicID != e.boundNICID { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } nicID = e.boundNICID @@ -2164,16 +2168,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc case StateConnecting, StateSynSent, StateSynRecv: // A connection request has already been issued but hasn't completed // yet. - return tcpip.ErrAlreadyConnecting + return &tcpip.ErrAlreadyConnecting{} case StateError: if err := e.hardErrorLocked(); err != nil { return err } - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Find a route to the desired destination. @@ -2190,7 +2194,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2229,12 +2233,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { - if err != tcpip.ErrPortInUse || !reuse { + if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } transEPID := e.ID @@ -2278,9 +2282,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) - if err == tcpip.ErrPortInUse { + if _, ok := err.(*tcpip.ErrPortInUse); ok { return false, nil } return false, err @@ -2335,23 +2339,23 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } - return tcpip.ErrConnectStarted + return &tcpip.ErrConnectStarted{} } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection to its // peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() return e.shutdownLocked(flags) } -func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.shutdownFlags |= flags switch { case e.EndpointState().connected(): @@ -2366,7 +2370,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) // Wake up worker to terminate loop. e.notifyProtocolGoroutine(notifyTickleWorker) return nil @@ -2380,7 +2384,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // Already closed. e.sndBufMu.Unlock() if e.EndpointState() == StateTimeWait { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } @@ -2413,22 +2417,24 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } return nil default: - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } } // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) tcpip.Error { err := e.listen(backlog) - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } -func (e *endpoint) listen(backlog int) *tcpip.Error { +func (e *endpoint) listen(backlog int) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2446,7 +2452,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Adjust the size of the channel iff we can fix // existing pending connections into the new one. if len(e.acceptedChan) > backlog { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } if cap(e.acceptedChan) == backlog { return nil @@ -2478,11 +2484,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Endpoint must be bound before it can transition to listen mode. if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } @@ -2518,7 +2524,7 @@ func (e *endpoint) startAcceptedLoop() { // to an endpoint previously set to listen mode. // // addr if not-nil will contain the peer address of the returned endpoint. -func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2527,7 +2533,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. if rcvClosed || e.EndpointState() != StateListen { - return nil, nil, tcpip.ErrInvalidEndpointState + return nil, nil, &tcpip.ErrInvalidEndpointState{} } // Get the new accepted endpoint. @@ -2538,7 +2544,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. case n = <-e.acceptedChan: e.acceptCond.Signal() default: - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } if peerAddr != nil { *peerAddr = n.getRemoteAddress() @@ -2547,19 +2553,19 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. } // Bind binds the endpoint to a specific local port and optionally address. -func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) Bind(addr tcpip.FullAddress) (err tcpip.Error) { e.LockUser() defer e.UnlockUser() return e.bindLocked(addr) } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. if e.EndpointState() != StateInitial { - return tcpip.ErrAlreadyBound + return &tcpip.ErrAlreadyBound{} } e.BindAddr = addr.Addr @@ -2587,7 +2593,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { if len(addr.Addr) != 0 { nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } e.ID.LocalAddress = addr.Addr } @@ -2604,7 +2610,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2628,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2640,12 +2646,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() if !e.EndpointState().connected() { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return e.getRemoteAddress(), nil @@ -2677,7 +2683,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -2726,26 +2732,27 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlAddressUnreachable: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNetworkUnreachable{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } // updateSndBufferUsage is called by the protocol goroutine when room opens up // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { + sendBufferSize := e.getSendBufferSize() e.sndBufMu.Lock() - notify := e.sndBufUsed >= e.sndBufSize>>1 + notify := e.sndBufUsed >= sendBufferSize>>1 e.sndBufUsed -= v - // We only notify when there is half the sndBufSize available after + // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < e.sndBufSize>>1 + notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 e.sndBufMu.Unlock() if notify { @@ -2957,8 +2964,9 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() // Copy endpoint send state. + sndBufSize := e.getSendBufferSize() e.sndBufMu.Lock() - s.SndBufSize = e.sndBufSize + s.SndBufSize = sndBufSize s.SndBufUsed = e.sndBufUsed s.SndClosed = e.sndClosed s.SndBufInQueue = e.sndBufInQueue @@ -3023,12 +3031,16 @@ func (e *endpoint) completeState() stack.TCPEndpointState { rc := &e.snd.rc s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + Reord: rc.reorderSeen, + DSACKSeen: rc.dsackSeen, + ReoWnd: rc.reoWnd, + ReoWndIncr: rc.reoWndIncr, + ReoWndPersist: rc.reoWndPersist, + RTTSeq: rc.rttSeq, } return s } @@ -3103,3 +3115,17 @@ func (e *endpoint) Wait() { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// GetTCPSendBufferLimits is used to get send buffer size limits for TCP. +func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.SendBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ba67176b5..c21dbc682 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,9 +55,11 @@ func (e *endpoint) beforeSave() { case epState.connected() || epState.handshake(): if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) + panic(&tcpip.ErrSaveRejection{ + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + }) } - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.mu.Unlock() e.Close() e.mu.Lock() @@ -179,14 +181,16 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + sendBufferSize := e.getSendBufferSize() + if sendBufferSize < ss.Min || sendBufferSize > ss.Max { + panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max)) } } @@ -228,7 +232,8 @@ func (e *endpoint) Resume(s *stack.Stack) { // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } e.mu.Lock() @@ -265,7 +270,8 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -292,24 +298,6 @@ func (e *endpoint) Resume(s *stack.Stack) { } } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // saveRecentTSTime is invoked by stateify. func (e *endpoint) saveRecentTSTime() unixTime { return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} @@ -320,24 +308,6 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { e.recentTSTime = time.Unix(unix.second, unix.nano) } -// saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { - return "" - } - - return e.hardError.String() -} - -// loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { - if s == "" { - return - } - - e.hardError = tcpip.StringToError(s) -} - // saveMeasureTime is invoked by stateify. func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 596178625..2f9fe7ee0 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -143,12 +143,12 @@ func (r *ForwarderRequest) Complete(sendReset bool) { // CreateEndpoint creates a TCP endpoint for the connection request, performing // the 3-way handshake in the process. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { r.mu.Lock() defer r.mu.Unlock() if r.segment == nil { - return nil, tcpip.ErrInvalidEndpointState + return nil, &tcpip.ErrInvalidEndpointState{} } f := r.forwarder diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 1720370c9..04012cd40 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -161,13 +161,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } @@ -178,7 +178,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given tcp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.TCP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -216,7 +216,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. -func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -261,7 +261,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.Lock() @@ -283,7 +283,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.sendBufferSize = *v @@ -292,7 +292,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.recvBufferSize = *v @@ -310,7 +310,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi } // linux returns ENOENT when an invalid congestion control // is specified. - return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() @@ -340,7 +340,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPTimeWaitReuseOption: if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.timeWaitReuse = *v @@ -381,7 +381,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSynRetriesOption: if *v < 1 || *v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.synRetries = uint8(*v) @@ -389,12 +389,12 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.RLock() @@ -493,7 +493,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 307bacca5..d85cb405a 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -22,12 +22,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) -// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as -// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. -// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is -// 1, PTO is inflated by WCDelAckT time to compensate for a potential long -// delayed ACK timer at the receiver. -const wcDelayedACKTimeout = 200 * time.Millisecond +const ( + // wcDelayedACKTimeout is the recommended maximum delayed ACK timer + // value as defined in the RFC. It stands for worst case delayed ACK + // timer (WCDelAckT). When FlightSize is 1, PTO is inflated by + // WCDelAckT time to compensate for a potential long delayed ACK timer + // at the receiver. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. + wcDelayedACKTimeout = 200 * time.Millisecond + + // tcpRACKRecoveryThreshold is the number of loss recoveries for which + // the reorder window is inflated and after that the reorder window is + // reset to its initial value of minRTT/4. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2. + tcpRACKRecoveryThreshold = 16 +) // RACK is a loss detection algorithm used in TCP to detect packet loss and // reordering using transmission timestamp of the packets instead of packet or @@ -44,6 +53,11 @@ type rackControl struct { // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value + // exitedRecovery indicates if the connection is exiting loss recovery. + // This flag is set if the sender is leaving the recovery after + // receiving an ACK and is reset during updating of reorder window. + exitedRecovery bool + // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value @@ -51,15 +65,30 @@ type rackControl struct { // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration + // reorderSeen indicates if reordering has been detected on this + // connection. + reorderSeen bool + + // reoWnd is the reordering window time used for recording packet + // transmission times. It is used to defer the moment at which RACK + // marks a packet lost. + reoWnd time.Duration + + // reoWndIncr is the multiplier applied to adjust reorder window. + reoWndIncr uint8 + + // reoWndPersist is the number of loss recoveries before resetting + // reorder window. + reoWndPersist int8 + // rtt is the RTT of the most recently delivered packet on the // connection (either cumulatively acknowledged or selectively // acknowledged) that was not marked invalid as a possible spurious // retransmission. rtt time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool + // rttSeq is the SND.NXT when rtt is updated. + rttSeq seqnum.Value // xmitTime is the latest transmission timestamp of rackControl.seg. xmitTime time.Time `state:".(unixTime)"` @@ -75,29 +104,36 @@ type rackControl struct { // tlpHighRxt the value of sender.sndNxt at the time of sending // a TLP retransmission. tlpHighRxt seqnum.Value + + // snd is a reference to the sender. + snd *sender } // init initializes RACK specific fields. -func (rc *rackControl) init() { +func (rc *rackControl) init(snd *sender, iss seqnum.Value) { + rc.fack = iss + rc.reoWndIncr = 1 + rc.snd = snd rc.probeTimer.init(&rc.probeWaker) } // update will update the RACK related fields when an ACK has been received. -// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 -func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) { +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 +func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) + tsOffset := rc.snd.ep.tsOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: - // 1. When Timestamping option is available, if the TSVal is less than the - // transmit time of the most recent retransmitted packet. + // 1. When Timestamping option is available, if the TSVal is less than + // the transmit time of the most recent retransmitted packet. // 2. When RTT calculated for the packet is less than the smoothed RTT // for the connection. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // step 2 if seg.xmitCount > 1 { if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { - if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) { return } } @@ -149,9 +185,8 @@ func (rc *rackControl) detectReorder(seg *segment) { } } -// setDSACKSeen updates rack control if duplicate SACK is seen by the connection. -func (rc *rackControl) setDSACKSeen() { - rc.dsackSeen = true +func (rc *rackControl) setDSACKSeen(dsackSeen bool) { + rc.dsackSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -162,7 +197,7 @@ func (s *sender) shouldSchedulePTO() bool { // The connection supports SACK. s.ep.sackPermitted && // The connection is not in loss recovery. - (s.state != RTORecovery && s.state != SACKRecovery) && + (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. s.ep.scoreboard.Sacked() == 0 } @@ -193,7 +228,7 @@ func (s *sender) schedulePTO() { // probeTimerExpired is the same as TLP_send_probe() as defined in // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2. -func (s *sender) probeTimerExpired() *tcpip.Error { +func (s *sender) probeTimerExpired() tcpip.Error { if !s.rc.probeTimer.checkExpiration() { return nil } @@ -272,3 +307,82 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { } } } + +// updateRACKReorderWindow updates the reorder window. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// * Step 4: Update RACK reordering window +// To handle the prevalent small degree of reordering, RACK.reo_wnd serves as +// an allowance for settling time before marking a packet lost. RACK starts +// initially with a conservative window of min_RTT/4. If no reordering has +// been observed RACK uses reo_wnd of zero during loss recovery, in order to +// retransmit quickly, or when the number of DUPACKs exceeds the classic +// DUPACKthreshold. +func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { + dsackSeen := rc.dsackSeen + snd := rc.snd + + // React to DSACK once per round trip. + // If SND.UNA < RACK.rtt_seq: + // RACK.dsack = false + if snd.sndUna.LessThan(rc.rttSeq) { + dsackSeen = false + } + + // If RACK.dsack: + // RACK.reo_wnd_incr += 1 + // RACK.dsack = false + // RACK.rtt_seq = SND.NXT + // RACK.reo_wnd_persist = 16 + if dsackSeen { + rc.reoWndIncr++ + dsackSeen = false + rc.rttSeq = snd.sndNxt + rc.reoWndPersist = tcpRACKRecoveryThreshold + } else if rc.exitedRecovery { + // Else if exiting loss recovery: + // RACK.reo_wnd_persist -= 1 + // If RACK.reo_wnd_persist <= 0: + // RACK.reo_wnd_incr = 1 + rc.reoWndPersist-- + if rc.reoWndPersist <= 0 { + rc.reoWndIncr = 1 + } + rc.exitedRecovery = false + } + + // Reorder window is zero during loss recovery, or when the number of + // DUPACKs exceeds the classic DUPACKthreshold. + // If RACK.reord is FALSE: + // If in loss recovery: (If in fast or timeout recovery) + // RACK.reo_wnd = 0 + // Return + // Else if RACK.pkts_sacked >= RACK.dupthresh: + // RACK.reo_wnd = 0 + // return + if !rc.reorderSeen { + if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { + rc.reoWnd = 0 + return + } + + if snd.sackedOut >= nDupAckThreshold { + rc.reoWnd = 0 + return + } + } + + // Calculate reorder window. + // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr + // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) + snd.rtt.Lock() + srtt := snd.rtt.srtt + snd.rtt.Unlock() + rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) + if srtt < rc.reoWnd { + rc.reoWnd = srtt + } +} + +func (rc *rackControl) exitRecovery() { + rc.exitedRecovery = true +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 405a6dce7..7a7c402c4 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -347,7 +347,7 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { r.ep.rcvListMu.Lock() rcvClosed := r.ep.rcvClosed || r.closed r.ep.rcvListMu.Unlock() @@ -395,7 +395,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { break @@ -424,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } } @@ -443,7 +443,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { state := r.ep.EndpointState() closed := r.ep.closed diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 079d90848..dfc8fd248 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -48,28 +48,6 @@ const ( MaxRetries = 15 ) -// ccState indicates the current congestion control state for this sender. -type ccState int - -const ( - // Open indicates that the sender is receiving acks in order and - // no loss or dupACK's etc have been detected. - Open ccState = iota - // RTORecovery indicates that an RTO has occurred and the sender - // has entered an RTO based recovery phase. - RTORecovery - // FastRecovery indicates that the sender has entered FastRecovery - // based on receiving nDupAck's. This state is entered only when - // SACK is not in use. - FastRecovery - // SACKRecovery indicates that the sender has entered SACK based - // recovery. - SACKRecovery - // Disorder indicates the sender either received some SACK blocks - // or dupACK's. - Disorder -) - // congestionControl is an interface that must be implemented by any supported // congestion control algorithm. type congestionControl interface { @@ -204,7 +182,7 @@ type sender struct { maxSentAck seqnum.Value // state is the current state of congestion control for this endpoint. - state ccState + state tcpip.CongestionControlState // cc is the congestion control algorithm in use for this sender. cc congestionControl @@ -280,14 +258,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint highRxt: iss, rescueRxt: iss, }, - rc: rackControl{ - fack: iss, - }, gso: ep.gso != nil, } - s.rc.init() - if s.gso { s.ep.gso.MSS = uint16(maxPayloadSize) } @@ -295,6 +268,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.cc = s.initCongestionControl(ep.cc) s.lr = s.initLossRecovery() + s.rc.init(s, iss) // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. @@ -593,7 +567,7 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveRecovery() } - s.state = RTORecovery + s.state = tcpip.RTORecovery s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and @@ -1018,7 +992,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { + if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { if s.sndCwnd > InitialCwnd { s.sndCwnd = InitialCwnd } @@ -1062,14 +1036,14 @@ func (s *sender) enterRecovery() { s.fr.highRxt = s.sndUna s.fr.rescueRxt = s.sndUna if s.ep.sackPermitted { - s.state = SACKRecovery + s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1. s.rc.tlpRxtOut = false return } - s.state = FastRecovery + s.state = tcpip.FastRecovery s.ep.stack.Stats().TCP.FastRecovery.Increment() } @@ -1080,7 +1054,6 @@ func (s *sender) leaveRecovery() { // Deflate cwnd. It had been artificially inflated when new dups arrived. s.sndCwnd = s.sndSsthresh - s.cc.PostRecovery() } @@ -1166,7 +1139,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { s.fr.highRxt = s.sndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() - s.state = Disorder + s.state = tcpip.Disorder return false } @@ -1217,11 +1190,13 @@ func (s *sender) isDupAck(seg *segment) bool { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + s.rc.setDSACKSeen(false) + // Look for DSACK block. idx := 0 n := len(rcvdSeg.parsedOptions.SACKBlocks) if checkDSACK(rcvdSeg) { - s.rc.setDSACKSeen() + s.rc.setDSACKSeen(true) idx = 1 n-- } @@ -1242,7 +1217,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { for _, sb := range sackBlocks { for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true s.sackedOut += s.pCount(seg, s.maxPayloadSize) @@ -1412,6 +1387,17 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { acked := s.sndUna.Size(ack) s.sndUna = ack + // The remote ACK-ing at least 1 byte is an indication that we have a + // full-duplex connection to the remote as the only way we will receive an + // ACK is if the remote received data that we previously sent. + // + // As of writing, linux seems to only confirm a route as reachable when + // forward progress is made which is indicated by an ACK that removes data + // from the retransmit queue. + if acked > 0 { + s.ep.route.ConfirmReachable() + } + ackLeft := acked originalOutstanding := s.outstanding for ackLeft > 0 { @@ -1435,7 +1421,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Update the RACK fields if SACK is enabled. if s.ep.sackPermitted && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1464,7 +1450,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if !s.fr.active { s.cc.Update(originalOutstanding - s.outstanding) if s.fr.last.LessThan(s.sndUna) { - s.state = Open + s.state = tcpip.Open + // Update RACK when we are exiting fast or RTO + // recovery as described in the RFC + // draft-ietf-tcpm-rack-08 Section-7.2 Step 4. + s.rc.exitRecovery() } } @@ -1488,6 +1478,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } + // Update RACK reorder window. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // * Upon receiving an ACK: + // * Step 4: Update RACK reordering window + s.rc.updateRACKReorderWindow(rcvdSeg) + // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. if s.fr.active { @@ -1508,7 +1504,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // sendSegment sends the specified segment. -func (s *sender) sendSegment(seg *segment) *tcpip.Error { +func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() @@ -1539,7 +1535,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. -func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index f7aaee23f..ced3a9c58 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -21,13 +21,13 @@ package tcp_test import ( + "bytes" "fmt" "math" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" @@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in two shots. Packets will only be written at the // MTU size though. - half := data[:len(data)/2] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data[:len(data)/2]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } - half = data[len(data)/2:] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + r.Reset(data[len(data)/2:]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 342eb5eb8..a6a26b705 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -15,11 +15,12 @@ package tcp_test import ( + "bytes" + "fmt" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -61,14 +62,16 @@ func TestRACKUpdate(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(maxPayload) + data := make([]byte, maxPayload) for i := range data { data[i] = byte(i) } // Write the data. xmitTime = time.Now() - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -114,13 +117,15 @@ func TestRACKDetectReorder(t *testing.T) { }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNumToVerify * maxPayload) + data := make([]byte, ackNumToVerify*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -141,17 +146,19 @@ func TestRACKDetectReorder(t *testing.T) { <-probeDone } -func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(numPackets * maxPayload) + data := make([]byte, numPackets*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -528,3 +535,64 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) { // ACK before the test completes. <-probeDone } + +func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) { + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK detects DSACK. + n++ + if n < numACK { + return + } + + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + return + } + + if state.Sender.RACKState.ReoWndIncr != 1 { + probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr) + return + } + + if state.Sender.RACKState.ReoWndPersist > 0 { + probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist) + return + } + probeDone <- nil + }) +} + +func TestRACKCheckReorderWindow(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan error) + const ackNumToVerify = 3 + addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) + + const numPackets = 7 + sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed packets [2-6] which indicates there is reordering + // in the connection. + bytesRead += 6 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + if err := <-probeDone; err != nil { + t.Fatalf("unexpected values for RACK variables: %v", err) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 6635bb815..5024bc925 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -15,6 +15,7 @@ package tcp_test import ( + "bytes" "fmt" "log" "reflect" @@ -22,7 +23,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) { createConnectedWithSACKAndTS(c) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 93683b921..da2730e27 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "math" + "strings" "testing" "time" @@ -26,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -48,7 +48,7 @@ type endpointTester struct { } // CheckReadError issues a read to the endpoint and checking for an error. -func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { +func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { t.Helper() res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) if got != want { @@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for receive to be notified. select { case <-notifyRead: @@ -128,8 +128,11 @@ func TestGiveUpConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventHUp) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Close the connection, wait for completion. @@ -140,8 +143,11 @@ func TestGiveUpConnect(t *testing.T) { // Call Connect again to retreive the handshake failure status // and stats updates. - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAborted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) + } } if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { @@ -194,8 +200,11 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { c.EP = ep want := stats.TCP.FailedConnectionAttempts.Value() + 1 - if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute) + { + err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + } } if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { @@ -211,7 +220,7 @@ func TestCloseWithoutConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -384,7 +393,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -925,8 +934,11 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + { + err := c.EP.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } } // Receive SYN packet with our user supplied MSS. @@ -1347,10 +1359,9 @@ func TestTOSV4(t *testing.T) { testV4Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1396,10 +1407,9 @@ func TestTrafficClassV6(t *testing.T) { testV6Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1444,7 +1454,8 @@ func TestConnectBindToDevice(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("unexpected return value from Connect: %s", err) } @@ -1504,8 +1515,9 @@ func TestSynSent(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { - t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) + err := c.EP.Connect(addr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) } // Receive SYN packet. @@ -1550,9 +1562,9 @@ func TestSynSent(t *testing.T) { ept := endpointTester{c.EP} if test.reset { - ept.CheckReadError(t, tcpip.ErrConnectionRefused) + ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) } else { - ept.CheckReadError(t, tcpip.ErrAborted) + ept.CheckReadError(t, &tcpip.ErrAborted{}) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { @@ -1578,7 +1590,7 @@ func TestOutOfOrderReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1603,7 +1615,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send the first 3 bytes now. c.SendPacket(data[:3], &context.Headers{ @@ -1642,7 +1654,7 @@ func TestOutOfOrderFlood(t *testing.T) { c.CreateConnected(789, 30000, rcvBufSz) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send 100 packets before the actual one that is expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1718,7 +1730,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1786,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1868,13 +1880,13 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { t.Fatalf("Shutdown failed: %s", err) } - ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) @@ -1893,7 +1905,7 @@ func TestFullWindowReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies // the provided buffer value by tcp.SegOverheadFactor to calculate the actual @@ -2054,7 +2066,7 @@ func TestNoWindowShrinking(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. @@ -2176,10 +2188,9 @@ func TestSimpleSend(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2217,10 +2228,9 @@ func TestZeroWindowSend(t *testing.T) { c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2285,10 +2295,9 @@ func TestScaledWindowConnect(t *testing.T) { }) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2317,10 +2326,9 @@ func TestNonScaledWindowConnect(t *testing.T) { c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2376,7 +2384,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2391,10 +2399,9 @@ func TestScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2450,7 +2457,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2465,10 +2472,9 @@ func TestNonScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2632,9 +2638,10 @@ func TestSegmentMerging(t *testing.T) { // Send tcp.InitialCwnd number of segments to fill up // InitialWindow but don't ACK. That should prevent // anymore packets from going out. + var r bytes.Reader for i := 0; i < tcp.InitialCwnd; i++ { - view := buffer.NewViewFromBytes([]byte{0}) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset([]byte{0}) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2644,8 +2651,8 @@ func TestSegmentMerging(t *testing.T) { var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2714,8 +2721,9 @@ func TestDelay(t *testing.T) { var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2761,8 +2769,9 @@ func TestUndelay(t *testing.T) { allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2845,8 +2854,9 @@ func TestMSSNotDelayed(t *testing.T) { allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2894,10 +2904,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { data[i] = byte(i) } - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2963,7 +2972,7 @@ func TestSetTTL(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -2973,8 +2982,11 @@ func TestSetTTL(t *testing.T) { t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("unexpected return value from Connect: %s", err) + } } // Receive SYN packet. @@ -3034,7 +3046,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -3090,7 +3102,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -3115,9 +3127,9 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) { defer c.Cleanup() s := c.Stack() - ch := make(chan *tcpip.Error, 1) + ch := make(chan tcpip.Error, 1) f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = r.CreateEndpoint(&c.WQ) ch <- err }) @@ -3146,7 +3158,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -3165,8 +3177,11 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventOut) defer c.WQ.EventUnregister(&we) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Receive SYN packet. @@ -3276,22 +3291,23 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err { - case tcpip.ErrWouldBlock: + switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { + case *tcpip.ErrWouldBlock: select { case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } - case tcpip.ErrConnectionReset: + case *tcpip.ErrConnectionReset: break loop default: - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -3328,9 +3344,11 @@ func TestSendOnResetConnection(t *testing.T) { time.Sleep(1 * time.Second) // Try to write. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) + var r bytes.Reader + r.Reset(make([]byte, 10)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -3352,7 +3370,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3409,7 +3429,9 @@ func TestMaxRTO(t *testing.T) { c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3458,7 +3480,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) } - if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(make([]byte, tc.size)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } pkt := c.GetPacket() @@ -3595,8 +3619,10 @@ func TestFinWithNoPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3667,9 +3693,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // Write enough segments to fill the congestion window before ACK'ing // any of them. - view := buffer.NewView(10) + view := make([]byte, 10) + var r bytes.Reader for i := tcp.InitialCwnd; i > 0; i-- { - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } } @@ -3754,8 +3782,10 @@ func TestFinWithPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3781,7 +3811,8 @@ func TestFinWithPendingData(t *testing.T) { }) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3841,8 +3872,10 @@ func TestFinWithPartialAck(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3879,7 +3912,8 @@ func TestFinWithPartialAck(t *testing.T) { ) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3985,8 +4019,10 @@ func scaledSendWindow(t *testing.T, scale uint8) { }) // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 65535) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4170,7 +4206,7 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -4249,10 +4285,13 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. - ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) + { + _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) + if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { + t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + } } } @@ -4263,7 +4302,7 @@ func TestReusePort(t *testing.T) { defer c.Cleanup() // First case, just an endpoint that was bound. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed; %s", err) @@ -4293,8 +4332,11 @@ func TestReusePort(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } c.EP.Close() @@ -4351,9 +4393,9 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + s, err := ep.SocketOptions().GetSendBufferSize() if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + t.Fatalf("GetSendBufferSize failed: %s", err) } if int(s) != v { @@ -4459,9 +4501,7 @@ func TestMinMaxBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(149, true) checkSendBufferSize(t, ep, 300) @@ -4473,9 +4513,7 @@ func TestMinMaxBufferSizes(t *testing.T) { // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) @@ -4505,11 +4543,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -4529,7 +4567,7 @@ func TestBindToDeviceOption(t *testing.T) { } } -func makeStack() (*stack.Stack, *tcpip.Error) { +func makeStack() (*stack.Stack, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -4599,8 +4637,11 @@ func TestSelfConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventOut) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } <-notifyCh @@ -4610,9 +4651,9 @@ func TestSelfConnect(t *testing.T) { // Write something. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4752,9 +4793,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("Bind(%d) failed: %s", i, err) } } - want := tcpip.ErrConnectStarted + var want tcpip.Error = &tcpip.ErrConnectStarted{} if collides { - want = tcpip.ErrNoPortAvailable + want = &tcpip.ErrNoPortAvailable{} } if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) @@ -4785,12 +4826,13 @@ func TestPathMTUDiscovery(t *testing.T) { // Send 3200 bytes of data. const writeSize = 3200 - data := buffer.NewView(writeSize) + data := make([]byte, writeSize) for i := range data { data[i] = byte(i) } - - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4878,11 +4920,11 @@ func TestTCPEndpointProbe(t *testing.T) { func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, tc := range testCases { @@ -4964,11 +5006,11 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { func TestEndpointSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, connected := range []bool{false, true} { @@ -4978,7 +5020,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5074,12 +5116,14 @@ func TestKeepalive(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5163,7 +5207,7 @@ func TestKeepalive(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) @@ -5270,7 +5314,7 @@ func TestListenBacklogFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5313,7 +5357,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5330,7 +5374,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5342,7 +5386,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5358,7 +5402,9 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -5583,7 +5629,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5658,7 +5704,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.WQ.EventUnregister(&we) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5674,7 +5720,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() tcp = header.TCP(header.IPv4(pkt).Payload()) if string(tcp.Payload()) != data { @@ -5692,7 +5740,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5733,7 +5781,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5749,7 +5797,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5763,7 +5811,7 @@ func TestSYNRetransmit(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5807,7 +5855,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5882,12 +5930,13 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { }) newEP, _, err := c.EP.Accept(nil) - - if err != nil && err != tcpip.ErrWouldBlock { + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: t.Fatalf("Accept failed: %s", err) } - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Try to accept the connections in the backlog. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -5908,7 +5957,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil { + var r strings.Reader + r.Reset(data) + if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5953,7 +6004,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6023,7 +6074,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6055,7 +6106,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } ept := endpointTester{ep} - ept.CheckReadError(t, tcpip.ErrNotConnected) + ept.CheckReadError(t, &tcpip.ErrNotConnected{}) if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } @@ -6075,7 +6126,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6091,8 +6142,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected) + { + err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { + t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + } } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { @@ -6211,7 +6265,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } } @@ -6335,7 +6389,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } totalCopied += res.Count @@ -6527,7 +6581,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6646,7 +6700,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6753,7 +6807,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6843,7 +6897,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6917,7 +6971,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -7067,7 +7121,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -7103,10 +7157,10 @@ func TestTCPCloseWithData(t *testing.T) { // Now write a few bytes and then close the endpoint. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7204,8 +7258,10 @@ func TestTCPUserTimeout(t *testing.T) { } // Send some data and wait before ACKing it. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7256,7 +7312,7 @@ func TestTCPUserTimeout(t *testing.T) { ) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) @@ -7300,7 +7356,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Now receive 1 keepalives, but don't ACK it. b := c.GetPacket() @@ -7339,7 +7395,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { ), ) - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } @@ -7494,8 +7550,9 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Send data. This should result in an acceptable endpoint. @@ -7552,8 +7609,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7675,13 +7733,13 @@ func TestSetStackTimeWaitReuse(t *testing.T) { s := c.Stack() testCases := []struct { v int - err *tcpip.Error + err tcpip.Error }{ {int(tcpip.TCPTimeWaitReuseDisabled), nil}, {int(tcpip.TCPTimeWaitReuseGlobal), nil}, {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, } for _, tc := range testCases { diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index b65091c3c..5a9745ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -22,7 +22,6 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS // Now send some data and validate that timestamp is echoed correctly in the ACK. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } @@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd // Now send some data with the accepted connection endpoint and validate // that no timestamp option is sent in the TCP segment. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ee55f030c..b1cb9a324 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -586,7 +586,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int // is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 // only endpoint instead of a default dual stack socket. func (c *Context) CreateV6Endpoint(v6only bool) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -689,7 +689,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -749,7 +750,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // Create creates a TCP endpoint. func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -887,7 +888,7 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { // It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK // does not carry an option that was not requested. func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) @@ -903,7 +904,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} err = c.EP.Connect(testFullAddr) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) } // Receive SYN packet. @@ -1054,7 +1055,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9f9b3d510..31a5ddce9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -97,9 +97,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. state EndpointState @@ -111,8 +109,8 @@ type endpoint struct { multicastNICID tcpip.NICID portFlags ports.Flags - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // Values used to reserve a port or register a transport endpoint. // (which ever happens first). @@ -176,18 +174,18 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetMulticastLoop(true) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -217,7 +215,7 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() @@ -227,7 +225,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -246,7 +244,7 @@ func (e *endpoint) Close() { switch e.EndpointState() { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -284,7 +282,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err } @@ -292,10 +290,10 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -342,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -353,7 +351,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.EndpointState() { case StateInitial: case StateConnected: @@ -361,11 +359,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case StateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -391,7 +389,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { localAddr := e.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -417,18 +415,18 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -438,14 +436,14 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if err := e.LastError(); err != nil { return 0, err } // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -461,7 +459,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. @@ -482,9 +480,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -492,7 +493,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if to.Port == 0 { // Port 0 is an invalid port to send to. - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } dst, netProto, err := e.checkV4MappedLocked(*to) @@ -511,19 +512,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, tcpip.ErrBroadcastDisabled + return 0, &tcpip.ErrBroadcastDisabled{} } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. so := e.SocketOptions() if so.GetRecvError() { so.QueueLocalErr( - tcpip.ErrMessageTooLong, + &tcpip.ErrMessageTooLong{}, route.NetProto, header.UDPMaximumPacketSize, tcpip.FullAddress{ @@ -534,7 +535,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc v, ) } - return 0, tcpip.ErrMessageTooLong + return 0, &tcpip.ErrMessageTooLong{} } ttl := e.ttl @@ -584,13 +585,13 @@ func (e *endpoint) OnReusePortSet(v bool) { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: // Return not supported if the value is not disabling path // MTU discovery. if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.MulticastTTLOption: @@ -632,25 +633,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvBufSizeMax = v e.mu.Unlock() return nil - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) - } - - if v < ss.Min { - v = ss.Min - } - if v > ss.Max { - v = ss.Max - } - - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -661,7 +643,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -683,17 +665,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { if nic != 0 { if !e.stack.CheckNIC(nic) { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } else { nic = e.stack.CheckLocalAddress(0, netProto, addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } if e.BindNICID != 0 && e.BindNICID != nic { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.multicastNICID = nic @@ -701,7 +683,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.AddMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -717,7 +699,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -726,7 +708,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToInsert]; ok { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -737,7 +719,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -752,7 +734,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -761,7 +743,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToRemove]; !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -777,7 +759,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.IPv4TOSOption: e.mu.RLock() @@ -811,12 +793,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax @@ -830,12 +806,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -846,14 +822,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { e.mu.Unlock() default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), Data: data, @@ -903,7 +879,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -912,7 +888,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (e *endpoint) Disconnect() *tcpip.Error { +func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -930,12 +906,12 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { - var err *tcpip.Error + var err tcpip.Error id = stack.TransportEndpointID{ LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { return err } @@ -950,7 +926,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.setEndpointState(StateInitial) } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() @@ -961,10 +937,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { if addr.Port == 0 { // We don't support connecting to port zero. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.mu.Lock() @@ -981,12 +957,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1023,7 +999,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { oldPortFlags := e.boundPortFlags - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { r.Release() return err @@ -1031,11 +1007,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id e.boundBindToDevice = btd + if e.route != nil { + // If the endpoint was already connected then make sure we release the + // previous route. + e.route.Release() + } e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID @@ -1051,20 +1032,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. if state := e.EndpointState(); state != StateBound && state != StateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } e.shutdownFlags |= flags @@ -1084,16 +1065,16 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) @@ -1104,7 +1085,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} @@ -1112,11 +1093,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ return id, bindToDevice, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1140,7 +1121,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -1148,7 +1129,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { return err } @@ -1170,7 +1151,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -1186,7 +1167,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -1203,12 +1184,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.EndpointState() != StateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -1341,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -1397,7 +1378,7 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt default: panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) } - e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt) + e.onICMPError(&tcpip.ErrConnectionRefused{}, errType, errCode, extra, pkt) return } } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 13b72dc88..21a6aa460 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,24 +37,6 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). @@ -91,6 +73,7 @@ func (e *endpoint) Resume(s *stack.Stack) { defer e.mu.Unlock() e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { @@ -113,7 +96,7 @@ func (e *endpoint) Resume(s *stack.Stack) { netProto = header.IPv6ProtocolNumber } - var err *tcpip.Error + var err tcpip.Error if state == StateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { @@ -122,7 +105,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } @@ -131,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 49e673d58..705ad1f64 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -69,7 +69,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { } // CreateEndpoint creates a connected UDP endpoint for the session request. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { netHdr := r.pkt.Network() route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) if err != nil { @@ -77,7 +77,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { + if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 91420edd3..427fdd0c9 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -54,13 +54,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } @@ -71,7 +71,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given udp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.UDP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -94,13 +94,13 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4e2123fe9..64c5298d3 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -353,7 +353,7 @@ func (c *testContext) cleanup() { func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { c.t.Helper() - var err *tcpip.Error + var err tcpip.Error c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { c.t.Fatal("NewEndpoint failed: ", err) @@ -555,11 +555,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -599,7 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe var buf bytes.Buffer res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for data to become available. select { case <-ch: @@ -703,8 +703,11 @@ func TestBindReservedPort(t *testing.T) { t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(addr) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } } @@ -716,8 +719,11 @@ func TestBindReservedPort(t *testing.T) { defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint // above, since the endpoint is dual-stack. - if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { @@ -806,11 +812,11 @@ func TestV4ReadSelfSource(t *testing.T) { for _, tt := range []struct { name string handleLocal bool - wantErr *tcpip.Error + wantErr tcpip.Error wantInvalidSource uint64 }{ {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, + {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ @@ -959,15 +965,16 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { // testFailingWrite sends a packet of the given test flow into the UDP endpoint // and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { +func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - payload := buffer.View(newPayload()) - _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + var r bytes.Reader + r.Reset(newPayload()) + _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) c.checkEndpointWriteStats(1, epstats, gotErr) @@ -1007,8 +1014,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, } } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("Write failed: %s", err) } @@ -1089,7 +1098,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { testWrite(c, unicastV6) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) const want = 1 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) @@ -1110,7 +1119,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { testWrite(c, unicastV4in6) // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV4WriteOnV6Only(t *testing.T) { @@ -1120,7 +1129,7 @@ func TestV4WriteOnV6Only(t *testing.T) { c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { @@ -1135,7 +1144,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV6WriteOnConnected(t *testing.T) { @@ -1183,8 +1192,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { writeOpts := tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) } @@ -1192,8 +1203,11 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) } - if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { - c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + { + err := c.ep.LastError() + if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } } }) } @@ -2303,21 +2317,21 @@ func TestShutdownWrite(t *testing.T) { t.Fatalf("Shutdown failed: %s", err) } - testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) + testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) } -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { + switch err.(type) { case nil: want.PacketsSent.IncrementBy(incr) - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: want.WriteErrors.InvalidArgs.IncrementBy(incr) - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: want.WriteErrors.WriteClosed.IncrementBy(incr) - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: want.SendErrors.NoRoute.IncrementBy(incr) default: want.SendErrors.SendToNetworkFailed.IncrementBy(incr) @@ -2327,11 +2341,11 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE } } -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil, tcpip.ErrWouldBlock: - case tcpip.ErrClosedForReceive: + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + case *tcpip.ErrClosedForReceive: want.ReadErrors.ReadClosed.IncrementBy(incr) default: c.t.Errorf("Endpoint error missing stats update err %v", err) @@ -2497,31 +2511,54 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } defer ep.Close() - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + var r bytes.Reader + data := []byte{1, 2, 3, 4} to := tcpip.FullAddress{ Addr: test.remoteAddr, Port: 80, } opts := tcpip.WriteOptions{To: &to} - expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { + if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { + return nil + } + return &tcpip.ErrBroadcastDisabled{} + } if !test.requiresBroadcastOpt { expectedErrWithoutBcastOpt = nil } - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + r.Reset(data) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } ep.SocketOptions().SetBroadcast(true) - if n, err := ep.Write(data, opts); err != nil { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != nil { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) } ep.SocketOptions().SetBroadcast(false) - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + r.Reset(data) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } }) } |