diff options
Diffstat (limited to 'pkg/tcpip')
133 files changed, 7999 insertions, 4254 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e7924e5c2..f979d22f0 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "tcpip", srcs = [ + "errors.go", "sock_err_list.go", "socketops.go", "tcpip.go", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index fdeec12d3..c188aaa18 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "bytes" "context" "errors" "io" @@ -247,7 +248,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { func (l *TCPListener) Accept() (net.Conn, error) { n, wq, err := l.ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) l.wq.EventRegister(&waitEntry, waiter.EventIn) @@ -256,7 +257,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { for { n, wq, err = l.ep.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } @@ -297,14 +298,14 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} res, err := ep.Read(&w, opts) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { res, err = ep.Read(&w, opts) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } select { @@ -315,7 +316,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s } } - if err == tcpip.ErrClosedForReceive { + if _, ok := err.(*tcpip.ErrClosedForReceive); ok { return 0, io.EOF } @@ -354,10 +355,8 @@ func (c *TCPConn) Write(b []byte) (int, error) { default: } - v := buffer.NewViewFromBytes(b) - // We must handle two soft failure conditions simultaneously: - // 1. Write may write nothing and return tcpip.ErrWouldBlock. + // 1. Write may write nothing and return *tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have // not already and wait to try again. // 2. Write may write fewer than the full number of bytes and return @@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) { // There is no guarantee that all of the condition #1s will occur before // all of the condition #2s or visa-versa. var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} + r bytes.Reader + nbytes int + entry waiter.Entry + ch <-chan struct{} ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) + for nbytes != len(b) { + r.Reset(b[nbytes:]) + n, err := c.ep.Write(&r, tcpip.WriteOptions{}) + nbytes += int(n) + switch err.(type) { + case nil: + case *tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(nil) + + c.wq.EventRegister(&entry, waiter.EventOut) + defer c.wq.EventUnregister(&entry) } else { // Don't wait immediately after registration in case more data // became available between when we last checked and when we setup @@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) { select { case <-deadline: return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: + case <-ch: + continue } } + default: + return nbytes, c.newOpError("write", errors.New(err.String())) } - - var n int64 - n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - - if err == nil { - return nbytes, nil } - - return nbytes, c.newOpError("write", errors.New(err.String())) + return nbytes, nil } // Close implements net.Conn.Close. @@ -502,7 +495,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, } err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { select { case <-ctx.Done(): ep.Close() @@ -644,17 +637,19 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} + writeOptions := tcpip.WriteOptions{} if addr != nil { ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + writeOptions.To = &tcpip.FullAddress{ + Addr: tcpip.Address(ua.IP), + Port: uint16(ua.Port), + } } - v := buffer.NewView(len(b)) - copy(v, b) - - n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) - if err == tcpip.ErrWouldBlock { + var r bytes.Reader + r.Reset(b) + n, err := c.ep.Write(&r, writeOptions) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.wq.EventRegister(&waitEntry, waiter.EventOut) @@ -666,8 +661,8 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - if err != tcpip.ErrWouldBlock { + n, err = c.ep.Write(&r, writeOptions) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index b196324c7..2b3ea4bdf 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -58,7 +58,7 @@ func TestTimeouts(t *testing.T) { } } -func newLoopbackStack() (*stack.Stack, *tcpip.Error) { +func newLoopbackStack() (*stack.Stack, tcpip.Error) { // Create the stack and add a NIC. s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, @@ -94,7 +94,7 @@ type testConnection struct { ep tcpip.Endpoint } -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { +func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -105,7 +105,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er wq.EventRegister(&entry, waiter.EventOut) err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { <-ch err = ep.LastError() } @@ -660,11 +660,13 @@ func TestTCPDialError(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - _, err := DialTCP(s, addr, ipv4.ProtocolNumber) - got, ok := err.(*net.OpError) - want := tcpip.ErrNoRoute - if !ok || got.Err.Error() != want.String() { - t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) + switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) { + case *net.OpError: + if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() { + t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{}) + } + default: + t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{}) } } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 0ac2000ca..07b4393a4 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -234,7 +234,7 @@ func IPv4RouterAlert() NetworkChecker { for { opt, done, err := iterator.Next() if err != nil { - t.Fatalf("error acquiring next IPv4 option %s", err) + t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer) } if done { break diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go new file mode 100644 index 000000000..af46da1d2 --- /dev/null +++ b/pkg/tcpip/errors.go @@ -0,0 +1,538 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" +) + +// Error represents an error in the netstack error space. +// +// The error interface is intentionally omitted to avoid loss of type +// information that would occur if these errors were passed as error. +type Error interface { + isError() + + // IgnoreStats indicates whether this error should be included in failure + // counts in tcpip.Stats structs. + IgnoreStats() bool + + fmt.Stringer +} + +// ErrAborted indicates the operation was aborted. +// +// +stateify savable +type ErrAborted struct{} + +func (*ErrAborted) isError() {} + +// IgnoreStats implements Error. +func (*ErrAborted) IgnoreStats() bool { + return false +} +func (*ErrAborted) String() string { + return "operation aborted" +} + +// ErrAddressFamilyNotSupported indicates the operation does not support the +// given address family. +// +// +stateify savable +type ErrAddressFamilyNotSupported struct{} + +func (*ErrAddressFamilyNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrAddressFamilyNotSupported) IgnoreStats() bool { + return false +} +func (*ErrAddressFamilyNotSupported) String() string { + return "address family not supported by protocol" +} + +// ErrAlreadyBound indicates the endpoint is already bound. +// +// +stateify savable +type ErrAlreadyBound struct{} + +func (*ErrAlreadyBound) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyBound) IgnoreStats() bool { + return true +} +func (*ErrAlreadyBound) String() string { return "endpoint already bound" } + +// ErrAlreadyConnected indicates the endpoint is already connected. +// +// +stateify savable +type ErrAlreadyConnected struct{} + +func (*ErrAlreadyConnected) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyConnected) IgnoreStats() bool { + return true +} +func (*ErrAlreadyConnected) String() string { return "endpoint is already connected" } + +// ErrAlreadyConnecting indicates the endpoint is already connecting. +// +// +stateify savable +type ErrAlreadyConnecting struct{} + +func (*ErrAlreadyConnecting) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyConnecting) IgnoreStats() bool { + return true +} +func (*ErrAlreadyConnecting) String() string { return "endpoint is already connecting" } + +// ErrBadAddress indicates a bad address was provided. +// +// +stateify savable +type ErrBadAddress struct{} + +func (*ErrBadAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadAddress) IgnoreStats() bool { + return false +} +func (*ErrBadAddress) String() string { return "bad address" } + +// ErrBadBuffer indicates a bad buffer was provided. +// +// +stateify savable +type ErrBadBuffer struct{} + +func (*ErrBadBuffer) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadBuffer) IgnoreStats() bool { + return false +} +func (*ErrBadBuffer) String() string { return "bad buffer" } + +// ErrBadLocalAddress indicates a bad local address was provided. +// +// +stateify savable +type ErrBadLocalAddress struct{} + +func (*ErrBadLocalAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadLocalAddress) IgnoreStats() bool { + return false +} +func (*ErrBadLocalAddress) String() string { return "bad local address" } + +// ErrBroadcastDisabled indicates broadcast is not enabled on the endpoint. +// +// +stateify savable +type ErrBroadcastDisabled struct{} + +func (*ErrBroadcastDisabled) isError() {} + +// IgnoreStats implements Error. +func (*ErrBroadcastDisabled) IgnoreStats() bool { + return false +} +func (*ErrBroadcastDisabled) String() string { return "broadcast socket option disabled" } + +// ErrClosedForReceive indicates the endpoint is closed for incoming data. +// +// +stateify savable +type ErrClosedForReceive struct{} + +func (*ErrClosedForReceive) isError() {} + +// IgnoreStats implements Error. +func (*ErrClosedForReceive) IgnoreStats() bool { + return false +} +func (*ErrClosedForReceive) String() string { return "endpoint is closed for receive" } + +// ErrClosedForSend indicates the endpoint is closed for outgoing data. +// +// +stateify savable +type ErrClosedForSend struct{} + +func (*ErrClosedForSend) isError() {} + +// IgnoreStats implements Error. +func (*ErrClosedForSend) IgnoreStats() bool { + return false +} +func (*ErrClosedForSend) String() string { return "endpoint is closed for send" } + +// ErrConnectStarted indicates the endpoint is connecting asynchronously. +// +// +stateify savable +type ErrConnectStarted struct{} + +func (*ErrConnectStarted) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectStarted) IgnoreStats() bool { + return true +} +func (*ErrConnectStarted) String() string { return "connection attempt started" } + +// ErrConnectionAborted indicates the connection was aborted. +// +// +stateify savable +type ErrConnectionAborted struct{} + +func (*ErrConnectionAborted) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionAborted) IgnoreStats() bool { + return false +} +func (*ErrConnectionAborted) String() string { return "connection aborted" } + +// ErrConnectionRefused indicates the connection was refused. +// +// +stateify savable +type ErrConnectionRefused struct{} + +func (*ErrConnectionRefused) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionRefused) IgnoreStats() bool { + return false +} +func (*ErrConnectionRefused) String() string { return "connection was refused" } + +// ErrConnectionReset indicates the connection was reset. +// +// +stateify savable +type ErrConnectionReset struct{} + +func (*ErrConnectionReset) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionReset) IgnoreStats() bool { + return false +} +func (*ErrConnectionReset) String() string { return "connection reset by peer" } + +// ErrDestinationRequired indicates the operation requires a destination +// address, and one was not provided. +// +// +stateify savable +type ErrDestinationRequired struct{} + +func (*ErrDestinationRequired) isError() {} + +// IgnoreStats implements Error. +func (*ErrDestinationRequired) IgnoreStats() bool { + return false +} +func (*ErrDestinationRequired) String() string { return "destination address is required" } + +// ErrDuplicateAddress indicates the operation encountered a duplicate address. +// +// +stateify savable +type ErrDuplicateAddress struct{} + +func (*ErrDuplicateAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrDuplicateAddress) IgnoreStats() bool { + return false +} +func (*ErrDuplicateAddress) String() string { return "duplicate address" } + +// ErrDuplicateNICID indicates the operation encountered a duplicate NIC ID. +// +// +stateify savable +type ErrDuplicateNICID struct{} + +func (*ErrDuplicateNICID) isError() {} + +// IgnoreStats implements Error. +func (*ErrDuplicateNICID) IgnoreStats() bool { + return false +} +func (*ErrDuplicateNICID) String() string { return "duplicate nic id" } + +// ErrInvalidEndpointState indicates the endpoint is in an invalid state. +// +// +stateify savable +type ErrInvalidEndpointState struct{} + +func (*ErrInvalidEndpointState) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidEndpointState) IgnoreStats() bool { + return false +} +func (*ErrInvalidEndpointState) String() string { return "endpoint is in invalid state" } + +// ErrInvalidOptionValue indicates an invalid option value was provided. +// +// +stateify savable +type ErrInvalidOptionValue struct{} + +func (*ErrInvalidOptionValue) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidOptionValue) IgnoreStats() bool { + return false +} +func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" } + +// ErrMalformedHeader indicates the operation encountered a malformed header. +// +// +stateify savable +type ErrMalformedHeader struct{} + +func (*ErrMalformedHeader) isError() {} + +// IgnoreStats implements Error. +func (*ErrMalformedHeader) IgnoreStats() bool { + return false +} +func (*ErrMalformedHeader) String() string { return "header is malformed" } + +// ErrMessageTooLong indicates the operation encountered a message whose length +// exceeds the maximum permitted. +// +// +stateify savable +type ErrMessageTooLong struct{} + +func (*ErrMessageTooLong) isError() {} + +// IgnoreStats implements Error. +func (*ErrMessageTooLong) IgnoreStats() bool { + return false +} +func (*ErrMessageTooLong) String() string { return "message too long" } + +// ErrNetworkUnreachable indicates the operation is not able to reach the +// destination network. +// +// +stateify savable +type ErrNetworkUnreachable struct{} + +func (*ErrNetworkUnreachable) isError() {} + +// IgnoreStats implements Error. +func (*ErrNetworkUnreachable) IgnoreStats() bool { + return false +} +func (*ErrNetworkUnreachable) String() string { return "network is unreachable" } + +// ErrNoBufferSpace indicates no buffer space is available. +// +// +stateify savable +type ErrNoBufferSpace struct{} + +func (*ErrNoBufferSpace) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoBufferSpace) IgnoreStats() bool { + return false +} +func (*ErrNoBufferSpace) String() string { return "no buffer space available" } + +// ErrNoPortAvailable indicates no port could be allocated for the operation. +// +// +stateify savable +type ErrNoPortAvailable struct{} + +func (*ErrNoPortAvailable) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoPortAvailable) IgnoreStats() bool { + return false +} +func (*ErrNoPortAvailable) String() string { return "no ports are available" } + +// ErrNoRoute indicates the operation is not able to find a route to the +// destination. +// +// +stateify savable +type ErrNoRoute struct{} + +func (*ErrNoRoute) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoRoute) IgnoreStats() bool { + return false +} +func (*ErrNoRoute) String() string { return "no route" } + +// ErrNoSuchFile is used to indicate that ENOENT should be returned the to +// calling application. +// +// +stateify savable +type ErrNoSuchFile struct{} + +func (*ErrNoSuchFile) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoSuchFile) IgnoreStats() bool { + return false +} +func (*ErrNoSuchFile) String() string { return "no such file" } + +// ErrNotConnected indicates the endpoint is not connected. +// +// +stateify savable +type ErrNotConnected struct{} + +func (*ErrNotConnected) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotConnected) IgnoreStats() bool { + return false +} +func (*ErrNotConnected) String() string { return "endpoint not connected" } + +// ErrNotPermitted indicates the operation is not permitted. +// +// +stateify savable +type ErrNotPermitted struct{} + +func (*ErrNotPermitted) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotPermitted) IgnoreStats() bool { + return false +} +func (*ErrNotPermitted) String() string { return "operation not permitted" } + +// ErrNotSupported indicates the operation is not supported. +// +// +stateify savable +type ErrNotSupported struct{} + +func (*ErrNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotSupported) IgnoreStats() bool { + return false +} +func (*ErrNotSupported) String() string { return "operation not supported" } + +// ErrPortInUse indicates the provided port is in use. +// +// +stateify savable +type ErrPortInUse struct{} + +func (*ErrPortInUse) isError() {} + +// IgnoreStats implements Error. +func (*ErrPortInUse) IgnoreStats() bool { + return false +} +func (*ErrPortInUse) String() string { return "port is in use" } + +// ErrQueueSizeNotSupported indicates the endpoint does not allow queue size +// operation. +// +// +stateify savable +type ErrQueueSizeNotSupported struct{} + +func (*ErrQueueSizeNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrQueueSizeNotSupported) IgnoreStats() bool { + return false +} +func (*ErrQueueSizeNotSupported) String() string { return "queue size querying not supported" } + +// ErrTimeout indicates the operation timed out. +// +// +stateify savable +type ErrTimeout struct{} + +func (*ErrTimeout) isError() {} + +// IgnoreStats implements Error. +func (*ErrTimeout) IgnoreStats() bool { + return false +} +func (*ErrTimeout) String() string { return "operation timed out" } + +// ErrUnknownDevice indicates an unknown device identifier was provided. +// +// +stateify savable +type ErrUnknownDevice struct{} + +func (*ErrUnknownDevice) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownDevice) IgnoreStats() bool { + return false +} +func (*ErrUnknownDevice) String() string { return "unknown device" } + +// ErrUnknownNICID indicates an unknown NIC ID was provided. +// +// +stateify savable +type ErrUnknownNICID struct{} + +func (*ErrUnknownNICID) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownNICID) IgnoreStats() bool { + return false +} +func (*ErrUnknownNICID) String() string { return "unknown nic id" } + +// ErrUnknownProtocol indicates an unknown protocol was requested. +// +// +stateify savable +type ErrUnknownProtocol struct{} + +func (*ErrUnknownProtocol) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownProtocol) IgnoreStats() bool { + return false +} +func (*ErrUnknownProtocol) String() string { return "unknown protocol" } + +// ErrUnknownProtocolOption indicates an unknown protocol option was provided. +// +// +stateify savable +type ErrUnknownProtocolOption struct{} + +func (*ErrUnknownProtocolOption) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownProtocolOption) IgnoreStats() bool { + return false +} +func (*ErrUnknownProtocolOption) String() string { return "unknown option for protocol" } + +// ErrWouldBlock indicates the operation would block. +// +// +stateify savable +type ErrWouldBlock struct{} + +func (*ErrWouldBlock) isError() {} + +// IgnoreStats implements Error. +func (*ErrWouldBlock) IgnoreStats() bool { + return true +} +func (*ErrWouldBlock) String() string { return "operation would block" } diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD index 114d43df3..bb9d44aff 100644 --- a/pkg/tcpip/faketime/BUILD +++ b/pkg/tcpip/faketime/BUILD @@ -6,10 +6,7 @@ go_library( name = "faketime", srcs = ["faketime.go"], visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "@com_github_dpjacques_clockwork//:go_default_library", - ], + deps = ["//pkg/tcpip"], ) go_test( diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index f7a4fbde1..fb819d7a8 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -17,10 +17,10 @@ package faketime import ( "container/heap" + "fmt" "sync" "time" - "github.com/dpjacques/clockwork" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -44,38 +44,85 @@ func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { return nil } +type notificationChannels struct { + mu struct { + sync.Mutex + + ch []<-chan struct{} + } +} + +func (n *notificationChannels) add(ch <-chan struct{}) { + n.mu.Lock() + defer n.mu.Unlock() + n.mu.ch = append(n.mu.ch, ch) +} + +// wait returns once all the notification channels are readable. +// +// Channels that are added while waiting on existing channels will be waited on +// as well. +func (n *notificationChannels) wait() { + for { + n.mu.Lock() + ch := n.mu.ch + n.mu.ch = nil + n.mu.Unlock() + + if len(ch) == 0 { + break + } + + for _, c := range ch { + <-c + } + } +} + // ManualClock implements tcpip.Clock and only advances manually with Advance // method. type ManualClock struct { - clock clockwork.FakeClock + // runningTimers tracks the completion of timer callbacks that began running + // immediately upon their scheduling. It is used to ensure the proper ordering + // of timer callback dispatch. + runningTimers notificationChannels + + mu struct { + sync.RWMutex - // mu protects the fields below. - mu sync.RWMutex + // now is the current (fake) time of the clock. + now time.Time - // times is min-heap of times. A heap is used for quick retrieval of the next - // upcoming time of scheduled work. - times *timeHeap + // times is min-heap of times. + times timeHeap - // waitGroups stores one WaitGroup for all work scheduled to execute at the - // same time via AfterFunc. This allows parallel execution of all functions - // passed to AfterFunc scheduled for the same time. - waitGroups map[time.Time]*sync.WaitGroup + // timers holds the timers scheduled for each time. + timers map[time.Time]map[*manualTimer]struct{} + } } // NewManualClock creates a new ManualClock instance. func NewManualClock() *ManualClock { - return &ManualClock{ - clock: clockwork.NewFakeClock(), - times: &timeHeap{}, - waitGroups: make(map[time.Time]*sync.WaitGroup), - } + c := &ManualClock{} + + c.mu.Lock() + defer c.mu.Unlock() + + // Set the initial time to a non-zero value since the zero value is used to + // detect inactive timers. + c.mu.now = time.Unix(0, 0) + c.mu.timers = make(map[time.Time]map[*manualTimer]struct{}) + + return c } var _ tcpip.Clock = (*ManualClock)(nil) // NowNanoseconds implements tcpip.Clock.NowNanoseconds. func (mc *ManualClock) NowNanoseconds() int64 { - return mc.clock.Now().UnixNano() + mc.mu.RLock() + defer mc.mu.RUnlock() + return mc.mu.now.UnixNano() } // NowMonotonic implements tcpip.Clock.NowMonotonic. @@ -85,128 +132,203 @@ func (mc *ManualClock) NowMonotonic() int64 { // AfterFunc implements tcpip.Clock.AfterFunc. func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - until := mc.clock.Now().Add(d) - wg := mc.addWait(until) - return &manualTimer{ + mt := &manualTimer{ clock: mc, - until: until, - timer: mc.clock.AfterFunc(d, func() { - defer wg.Done() - f() - }), + f: f, } -} -// addWait adds an additional wait to the WaitGroup for parallel execution of -// all work scheduled for t. Returns a reference to the WaitGroup modified. -func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup { - mc.mu.RLock() - wg, ok := mc.waitGroups[t] - mc.mu.RUnlock() + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() - if ok { - wg.Add(1) - return wg + mc.resetTimerLocked(mt, d) + return mt +} + +// resetTimerLocked schedules a timer to be fired after the given duration. +// +// Precondition: mc.mu and mt.mu must be locked. +func (mc *ManualClock) resetTimerLocked(mt *manualTimer, d time.Duration) { + if !mt.mu.firesAt.IsZero() { + panic("tried to reset an active timer") } - mc.mu.Lock() - heap.Push(mc.times, t) - mc.mu.Unlock() + t := mc.mu.now.Add(d) - wg = &sync.WaitGroup{} - wg.Add(1) + if !mc.mu.now.Before(t) { + // If the timer is scheduled to fire immediately, call its callback + // in a new goroutine immediately. + // + // It needs to be called in its own goroutine to escape its current + // execution context - like an actual timer. + ch := make(chan struct{}) + mc.runningTimers.add(ch) - mc.mu.Lock() - mc.waitGroups[t] = wg - mc.mu.Unlock() + go func() { + defer close(ch) + + mt.f() + }() - return wg + return + } + + mt.mu.firesAt = t + + timers, ok := mc.mu.timers[t] + if !ok { + timers = make(map[*manualTimer]struct{}) + mc.mu.timers[t] = timers + heap.Push(&mc.mu.times, t) + } + + timers[mt] = struct{}{} } -// removeWait removes a wait from the WaitGroup for parallel execution of all -// work scheduled for t. -func (mc *ManualClock) removeWait(t time.Time) { - mc.mu.RLock() - defer mc.mu.RUnlock() +// stopTimerLocked stops a timer from firing. +// +// Precondition: mc.mu and mt.mu must be locked. +func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { + t := mt.mu.firesAt + mt.mu.firesAt = time.Time{} + + if t.IsZero() { + panic("tried to stop an inactive timer") + } - wg := mc.waitGroups[t] - wg.Done() + timers, ok := mc.mu.timers[t] + if !ok { + err := fmt.Sprintf("tried to stop an active timer but the clock does not have anything scheduled for the timer @ t = %s %p\nScheduled timers @:", t.UTC(), mt) + for t := range mc.mu.timers { + err += fmt.Sprintf("%s\n", t.UTC()) + } + panic(err) + } + + if _, ok := timers[mt]; !ok { + panic(fmt.Sprintf("did not have an entry in timers for an active timer @ t = %s", t.UTC())) + } + + delete(timers, mt) + + if len(timers) == 0 { + delete(mc.mu.timers, t) + } } // Advance executes all work that have been scheduled to execute within d from -// the current time. Blocks until all work has completed execution. +// the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { - // Block until all the work is done - until := mc.clock.Now().Add(d) - for { - mc.mu.Lock() - if mc.times.Len() == 0 { - mc.mu.Unlock() - break - } + // We spawn goroutines for timers that were scheduled to fire at the time of + // being reset. Wait for those goroutines to complete before proceeding so + // that timer callbacks are called in the right order. + mc.runningTimers.wait() - t := heap.Pop(mc.times).(time.Time) + mc.mu.Lock() + defer mc.mu.Unlock() + + until := mc.mu.now.Add(d) + for mc.mu.times.Len() > 0 { + t := heap.Pop(&mc.mu.times).(time.Time) if t.After(until) { // No work to do - heap.Push(mc.times, t) - mc.mu.Unlock() + heap.Push(&mc.mu.times, t) break } - mc.mu.Unlock() - diff := t.Sub(mc.clock.Now()) - mc.clock.Advance(diff) + timers := mc.mu.timers[t] + delete(mc.mu.timers, t) + + mc.mu.now = t + + // Mark the timers as inactive since they will be fired. + // + // This needs to be done while holding mc's lock because we remove the entry + // in the map of timers for the current time. If an attempt to stop a + // timer is made after mc's lock was dropped but before the timer is + // marked inactive, we would panic since no entry exists for the time when + // the timer was expected to fire. + for mt := range timers { + mt.mu.Lock() + mt.mu.firesAt = time.Time{} + mt.mu.Unlock() + } - mc.mu.RLock() - wg := mc.waitGroups[t] - mc.mu.RUnlock() + // Release the lock before calling the timer's callback fn since the + // callback fn might try to schedule a timer which requires obtaining + // mc's lock. + mc.mu.Unlock() - wg.Wait() + for mt := range timers { + mt.f() + } + // The timer callbacks may have scheduled a timer to fire immediately. + // We spawn goroutines for these timers and need to wait for them to + // finish before proceeding so that timer callbacks are called in the + // right order. + mc.runningTimers.wait() mc.mu.Lock() - delete(mc.waitGroups, t) - mc.mu.Unlock() } - if now := mc.clock.Now(); until.After(now) { - mc.clock.Advance(until.Sub(now)) + + mc.mu.now = until +} + +func (mc *ManualClock) resetTimer(mt *manualTimer, d time.Duration) { + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() + + if !mt.mu.firesAt.IsZero() { + mc.stopTimerLocked(mt) } + + mc.resetTimerLocked(mt, d) +} + +func (mc *ManualClock) stopTimer(mt *manualTimer) bool { + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() + + if mt.mu.firesAt.IsZero() { + return false + } + + mc.stopTimerLocked(mt) + return true } type manualTimer struct { clock *ManualClock - timer clockwork.Timer + f func() - mu sync.RWMutex - until time.Time + mu struct { + sync.Mutex + + // firesAt is the time when the timer will fire. + // + // Zero only when the timer is not active. + firesAt time.Time + } } var _ tcpip.Timer = (*manualTimer)(nil) // Reset implements tcpip.Timer.Reset. -func (t *manualTimer) Reset(d time.Duration) { - if !t.timer.Reset(d) { - return - } - - t.mu.Lock() - defer t.mu.Unlock() - - t.clock.removeWait(t.until) - t.until = t.clock.clock.Now().Add(d) - t.clock.addWait(t.until) +func (mt *manualTimer) Reset(d time.Duration) { + mt.clock.resetTimer(mt, d) } // Stop implements tcpip.Timer.Stop. -func (t *manualTimer) Stop() bool { - if !t.timer.Stop() { - return false - } - - t.mu.RLock() - defer t.mu.RUnlock() - - t.clock.removeWait(t.until) - return true +func (mt *manualTimer) Stop() bool { + return mt.clock.stopTimer(mt) } type timeHeap []time.Time diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e6103f4bc..48ca60319 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package header import ( "encoding/binary" - "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -481,15 +480,13 @@ const ( IPv4OptionLengthOffset = 1 ) -// Potential errors when parsing generic IP options. -var ( - ErrIPv4OptZeroLength = errors.New("zero length IP option") - ErrIPv4OptDuplicate = errors.New("duplicate IP option") - ErrIPv4OptInvalid = errors.New("invalid IP option") - ErrIPv4OptMalformed = errors.New("malformed IP option") - ErrIPv4OptionTruncated = errors.New("truncated IP option") - ErrIPv4OptionAddress = errors.New("bad IP option address") -) +// IPv4OptParameterProblem indicates that a Parameter Problem message +// should be generated, and gives the offset in the current entity +// that should be used in that packet. +type IPv4OptParameterProblem struct { + Pointer uint8 + NeedICMP bool +} // IPv4Option is an interface representing various option types. type IPv4Option interface { @@ -583,8 +580,9 @@ func (i *IPv4OptionIterator) Finalize() IPv4Options { // It returns // - A slice of bytes holding the next option or nil if there is error. // - A boolean which is true if parsing of all the options is complete. -// - An error which is non-nil if an error condition was encountered. -func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { +// Undefined in the case of error. +// - An error indication which is non-nil if an error condition was found. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, *IPv4OptParameterProblem) { // The opts slice gets shorter as we process the options. When we have no // bytes left we are done. if len(i.options) == 0 { @@ -606,24 +604,22 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { // There are no more single byte options defined. All the rest have a length // field so we need to sanity check it. if len(i.options) == 1 { - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } optLen := i.options[IPv4OptionLengthOffset] - if optLen == 0 { - i.ErrCursor++ - return nil, true, ErrIPv4OptZeroLength - } + if optLen <= IPv4OptionLengthOffset || optLen > uint8(len(i.options)) { + // The actual error is in the length (2nd byte of the option) but we + // return the start of the option for compatibility with Linux. - if optLen == 1 { - i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed - } - - if optLen > uint8(len(i.options)) { - i.ErrCursor++ - return nil, true, ErrIPv4OptionTruncated + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } optionBody := i.options[:optLen] @@ -635,7 +631,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { case IPv4OptionTimestampType: if optLen < IPv4OptionTimestampHdrLength { i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } retval := IPv4OptionTimestamp(optionBody) return &retval, false, nil @@ -643,7 +642,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { case IPv4OptionRecordRouteType: if optLen < IPv4OptionRecordRouteHdrLength { i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } retval := IPv4OptionRecordRoute(optionBody) return &retval, false, nil diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 5580d6a78..f2403978c 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -453,9 +453,9 @@ const ( ) // ScopeForIPv6Address returns the scope for an IPv6 address. -func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { +func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { if len(addr) != IPv6AddressSize { - return GlobalScope, tcpip.ErrBadAddress + return GlobalScope, &tcpip.ErrBadAddress{} } switch { diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index e3fbd64f3..f10f446a6 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -299,7 +299,7 @@ func TestScopeForIPv6Address(t *testing.T) { name string addr tcpip.Address scope header.IPv6AddressScope - err *tcpip.Error + err tcpip.Error }{ { name: "Unique Local", @@ -329,15 +329,15 @@ func TestScopeForIPv6Address(t *testing.T) { name: "IPv4", addr: "\x01\x02\x03\x04", scope: header.GlobalScope, - err: tcpip.ErrBadAddress, + err: &tcpip.ErrBadAddress{}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { got, err := header.ScopeForIPv6Address(test.addr) - if err != test.err { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err) + if diff := cmp.Diff(test.err, err); diff != "" { + t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff) } if got != test.scope { t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a068d93a4..cd76272de 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -229,7 +229,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { p := PacketInfo{ Pkt: pkt, Proto: protocol, @@ -243,7 +243,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index 2f2d9d4ac..d873766a6 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -61,13 +61,13 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) return e.Endpoint.WritePacket(r, gso, proto, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 10072eac1..ae1394ebf 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -35,7 +35,6 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f86c383d8..0164d851b 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -57,7 +57,7 @@ import ( // linkDispatcher reads packets from the link FD and dispatches them to the // NetworkDispatcher. type linkDispatcher interface { - dispatch() (bool, *tcpip.Error) + dispatch() (bool, tcpip.Error) } // PacketDispatchMode are the various supported methods of receiving and @@ -118,7 +118,7 @@ type endpoint struct { // closed is a function to be called when the FD's peer (if any) closes // its end of the communication pipe. - closed func(*tcpip.Error) + closed func(tcpip.Error) inboundDispatchers []linkDispatcher dispatcher stack.NetworkDispatcher @@ -149,7 +149,7 @@ type Options struct { // ClosedFunc is a function to be called when an endpoint's peer (if // any) closes its end of the communication pipe. - ClosedFunc func(*tcpip.Error) + ClosedFunc func(tcpip.Error) // Address is the link address for this endpoint. Only used if // EthernetHeader is true. @@ -411,7 +411,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.hdrSize > 0 { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } @@ -451,7 +451,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } -func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { +func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { @@ -518,7 +518,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { // Preallocate to avoid repeated reallocation as we append to batch. // batchSz is 47 because when SWGSO is in use then a single 65KB TCP // segment can get split into 46 segments of 1420 bytes and a single 216 @@ -562,13 +562,13 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { } // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. -func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) } // dispatchLoop reads packets from the file descriptor in a loop and dispatches // them to the network stack. -func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error { +func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { for { cont, err := inboundDispatcher.dispatch() if err != nil || !cont { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 90da22d34..e82371798 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -96,7 +95,7 @@ func newContext(t *testing.T, opt *Options) *context { } done := make(chan struct{}, 2) - opt.ClosedFunc = func(*tcpip.Error) { + opt.ClosedFunc = func(tcpip.Error) { done <- struct{}{} } @@ -465,67 +464,85 @@ var capLengthTestCases = []struct { config: []int{1, 2, 3}, n: 3, wantUsed: 2, - wantLengths: []int{1, 2, 3}, + wantLengths: []int{1, 2}, }, } -func TestReadVDispatcherCapLength(t *testing.T) { +func TestIovecBuffer(t *testing.T) { for _, c := range capLengthTestCases { - // fd does not matter for this test. - d := readVDispatcher{fd: -1, e: &endpoint{}} - d.views = make([]buffer.View, len(c.config)) - d.iovecs = make([]syscall.Iovec, len(c.config)) - d.allocateViews(c.config) - - used := d.capViews(c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views)) - for i, v := range d.views { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - } -} + t.Run(c.comment, func(t *testing.T) { + b := newIovecBuffer(c.config, false /* skipsVnetHdr */) -func TestRecvMMsgDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - d := recvMMsgDispatcher{ - fd: -1, // fd does not matter for this test. - e: &endpoint{}, - views: make([][]buffer.View, 1), - iovecs: make([][]syscall.Iovec, 1), - msgHdrs: make([]rawfile.MMsgHdr, 1), - } + // Test initial allocation. + iovecs := b.nextIovecs() + if got, want := len(iovecs), len(c.config); got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } - for i := range d.views { - d.views[i] = make([]buffer.View, len(c.config)) - } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, len(c.config)) - } - for k, msgHdr := range d.msgHdrs { - msgHdr.Msg.Iov = &d.iovecs[k][0] - msgHdr.Msg.Iovlen = uint64(len(c.config)) - } + // Make a copy as iovecs points to internal slice. We will need this state + // later. + oldIovecs := append([]syscall.Iovec(nil), iovecs...) - d.allocateViews(c.config) + // Test the views that get pulled. + vv := b.pullViews(c.n) + var lengths []int + for _, v := range vv.Views() { + lengths = append(lengths, len(v)) + } + if !reflect.DeepEqual(lengths, c.wantLengths) { + t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths) + } - used := d.capViews(0, c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views[0])) - for i, v := range d.views[0] { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } + // Test that new views get reallocated. + for i, newIov := range b.nextIovecs() { + if i < c.wantUsed { + if newIov.Base == oldIovecs[i].Base { + t.Errorf("b.views[%d] should have been reallocated", i) + } + } else { + if newIov.Base != oldIovecs[i].Base { + t.Errorf("b.views[%d] should not have been reallocated", i) + } + } + } + }) + } +} +func TestIovecBufferSkipVnetHdr(t *testing.T) { + for _, test := range []struct { + desc string + readN int + wantLen int + }{ + { + desc: "nothing read", + readN: 0, + wantLen: 0, + }, + { + desc: "smaller than vnet header", + readN: virtioNetHdrSize - 1, + wantLen: 0, + }, + { + desc: "header skipped", + readN: virtioNetHdrSize + 100, + wantLen: 100, + }, + } { + t.Run(test.desc, func(t *testing.T) { + b := newIovecBuffer([]int{10, 20, 50, 50}, true) + // Pretend a read happend. + b.nextIovecs() + vv := b.pullViews(test.readN) + if got, want := vv.Size(), test.wantLen; got != want { + t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want) + } + if got, want := len(vv.ToOwnedView()), test.wantLen; got != want { + t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want) + } + }) } } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index c475dda20..a2b63fe6b 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -129,7 +129,7 @@ type packetMMapDispatcher struct { ringOffset int } -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) for hdr.tpStatus()&tpStatusUser == 0 { event := rawfile.PollEvent{ @@ -163,7 +163,7 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { // dispatch reads packets from an mmaped ring buffer and dispatches them to the // network stack. -func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { +func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { pkt, err := d.readMMappedPacket() if err != nil { return false, err diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 8c3ca86d6..ecae1ad2d 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -29,92 +29,124 @@ import ( // BufConfig defines the shape of the vectorised view used to read packets from the NIC. var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} -// readVDispatcher uses readv() system call to read inbound packets and -// dispatches them. -type readVDispatcher struct { - // fd is the file descriptor used to send and receive packets. - fd int - - // e is the endpoint this dispatcher is attached to. - e *endpoint - +type iovecBuffer struct { // views are the actual buffers that hold the packet contents. views []buffer.View // iovecs are initialized with base pointers/len of the corresponding - // entries in the views defined above, except when GSO is enabled then - // the first iovec points to a buffer for the vnet header which is - // stripped before the views are passed up the stack for further + // entries in the views defined above, except when GSO is enabled + // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header + // which is stripped before the views are passed up the stack for further // processing. iovecs []syscall.Iovec + + // sizes is an array of buffer sizes for the underlying views. sizes is + // immutable. + sizes []int + + // skipsVnetHdr is true if virtioNetHdr is to skipped. + skipsVnetHdr bool } -func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{fd: fd, e: e} - d.views = make([]buffer.View, len(BufConfig)) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - iovLen++ +func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer { + b := &iovecBuffer{ + views: make([]buffer.View, len(sizes)), + sizes: sizes, + skipsVnetHdr: skipsVnetHdr, } - d.iovecs = make([]syscall.Iovec, iovLen) - return d, nil + niov := len(b.views) + if b.skipsVnetHdr { + niov++ + } + b.iovecs = make([]syscall.Iovec, niov) + return b } -func (d *readVDispatcher) allocateViews(bufConfig []int) { - var vnetHdr [virtioNetHdrSize]byte +func (b *iovecBuffer) nextIovecs() []syscall.Iovec { vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if b.skipsVnetHdr { + var vnetHdr [virtioNetHdrSize]byte // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. - d.iovecs[0] = syscall.Iovec{ + b.iovecs[0] = syscall.Iovec{ Base: &vnetHdr[0], Len: uint64(virtioNetHdrSize), } vnetHdrOff++ } - for i := 0; i < len(bufConfig); i++ { - if d.views[i] != nil { + for i := range b.views { + if b.views[i] != nil { break } - b := buffer.NewView(bufConfig[i]) - d.views[i] = b - d.iovecs[i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), + v := buffer.NewView(b.sizes[i]) + b.views[i] = v + b.iovecs[i+vnetHdrOff] = syscall.Iovec{ + Base: &v[0], + Len: uint64(len(v)), } } + return b.iovecs } -func (d *readVDispatcher) capViews(n int, buffers []int) int { +func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView { + var views []buffer.View c := 0 - for i, s := range buffers { - c += s + if b.skipsVnetHdr { + c += virtioNetHdrSize if c >= n { - d.views[i].CapLength(s - (c - n)) - return i + 1 + // Nothing in the packet. + return buffer.NewVectorisedView(0, nil) + } + } + for i, v := range b.views { + c += len(v) + if c >= n { + b.views[i].CapLength(len(v) - (c - n)) + views = append([]buffer.View(nil), b.views[:i+1]...) + break } } - return len(buffers) + // Remove the first len(views) used views from the state. + for i := range views { + b.views[i] = nil + } + if b.skipsVnetHdr { + // Exclude the size of the vnet header. + n -= virtioNetHdrSize + } + return buffer.NewVectorisedView(n, views) } -// dispatch reads one packet from the file descriptor and dispatches it. -func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) +// readVDispatcher uses readv() system call to read inbound packets and +// dispatches them. +type readVDispatcher struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // buf is the iovec buffer that contains the packet contents. + buf *iovecBuffer +} + +func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + d := &readVDispatcher{fd: fd, e: e} + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) + return d, nil +} - n, err := rawfile.BlockingReadv(d.fd, d.iovecs) +// dispatch reads one packet from the file descriptor and dispatches it. +func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { + n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) if n == 0 || err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // Skip virtioNetHdr which is added before each packet, it - // isn't used and it isn't in a view. - n -= virtioNetHdrSize - } - used := d.capViews(n, BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + Data: d.buf.pullViews(n), }) var ( @@ -133,7 +165,12 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + return true, nil + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: @@ -145,11 +182,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[i] = nil - } - return true, nil } @@ -162,15 +194,8 @@ type recvMMsgDispatcher struct { // e is the endpoint this dispatcher is attached to. e *endpoint - // views is an array of array of buffers that contain packet contents. - views [][]buffer.View - - // iovecs is an array of array of iovec records where each iovec base - // pointer and length are initialzed to the corresponding view above, - // except when GSO is enabled then the first iovec in each array of - // iovecs points to a buffer for the vnet header which is stripped - // before the views are passed up the stack for further processing. - iovecs [][]syscall.Iovec + // bufs is an array of iovec buffers that contain packet contents. + bufs []*iovecBuffer // msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to // reference an array of iovecs in the iovecs field defined above. This @@ -187,74 +212,32 @@ const ( func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &recvMMsgDispatcher{ - fd: fd, - e: e, - } - d.views = make([][]buffer.View, MaxMsgsPerRecv) - for i := range d.views { - d.views[i] = make([]buffer.View, len(BufConfig)) - } - d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // virtioNetHdr is prepended before each packet. - iovLen++ + fd: fd, + e: e, + bufs: make([]*iovecBuffer, MaxMsgsPerRecv), + msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, iovLen) - } - d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv) - for i := range d.msgHdrs { - d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0] - d.msgHdrs[i].Msg.Iovlen = uint64(iovLen) + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + for i := range d.bufs { + d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } return d, nil } -func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int { - c := 0 - for i, s := range buffers { - c += s - if c >= n { - d.views[k][i].CapLength(s - (c - n)) - return i + 1 - } - } - return len(buffers) -} - -func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { - for k := 0; k < len(d.views); k++ { - var vnetHdr [virtioNetHdrSize]byte - vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // The kernel adds virtioNetHdr before each packet, but - // we don't use it, so so we allocate a buffer for it, - // add it in iovecs but don't add it in a view. - d.iovecs[k][0] = syscall.Iovec{ - Base: &vnetHdr[0], - Len: uint64(virtioNetHdrSize), - } - vnetHdrOff++ - } - for i := 0; i < len(bufConfig); i++ { - if d.views[k][i] != nil { - break - } - b := buffer.NewView(bufConfig[i]) - d.views[k][i] = b - d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), - } - } - } -} - // recvMMsgDispatch reads more than one packet at a time from the file // descriptor and dispatches it. -func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) +func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { + // Fill message headers. + for k := range d.msgHdrs { + if d.msgHdrs[k].Msg.Iovlen > 0 { + break + } + iovecs := d.bufs[k].nextIovecs() + iovLen := len(iovecs) + d.msgHdrs[k].Len = 0 + d.msgHdrs[k].Msg.Iov = &iovecs[0] + d.msgHdrs[k].Msg.Iovlen = uint64(iovLen) + } nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs) if err != nil { @@ -263,15 +246,14 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - n -= virtioNetHdrSize - } - used := d.capViews(k, int(n), BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + Data: d.bufs[k].pullViews(n), }) + // Mark that this iovec has been processed. + d.msgHdrs[k].Msg.Iovlen = 0 + var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress @@ -288,26 +270,24 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[k][0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + // Skip this packet. + continue + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: p = header.IPv6ProtocolNumber default: - return true, nil + // Skip this packet. + continue } } d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[k][i] = nil - } - } - - for k := 0; k < nMsgs; k++ { - d.msgHdrs[k].Len = 0 } return true, nil diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ac6a6be87..691467870 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // Construct data as the unparsed portion for the loopback packet. data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 316f508e6..668f72eee 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,10 +87,10 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } return endpoint.WritePackets(r, gso, pkts, protocol) } @@ -98,19 +98,19 @@ func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkt // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { endpoint, ok := m.routes[dest] if !ok { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } return endpoint.InjectOutbound(dest, packet) } diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 814a54f23..97ad9fdd5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -113,12 +113,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return e.child.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return e.child.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index c95cdd681..6cbe18a56 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -35,13 +35,13 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index d6e83a414..bbe84f220 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -45,12 +45,7 @@ type Endpoint struct { linkAddr tcpip.LinkAddress } -// WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - if !e.linked.IsAttached() { - return nil - } - +func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) { // Note that the local address from the perspective of this endpoint is the // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this @@ -70,16 +65,33 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw // // TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send // and receive queues. - go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) + go func() { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + } + }() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.linked.IsAttached() { + var pkts stack.PacketBufferList + pkts.PushBack(pkt) + e.deliverPackets(r, proto, pkts) + } return nil } // WritePackets implements stack.LinkEndpoint. -func (*Endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - panic("not implemented") +func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + if e.linked.IsAttached() { + e.deliverPackets(r, proto, pkts) + } + + return pkts.Len(), nil } // Attach implements stack.LinkEndpoint. diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 87035b034..128ef6e87 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -150,7 +150,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. pkt.EgressRoute = r @@ -158,26 +158,29 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] if !d.q.enqueue(pkt) { - return tcpip.ErrNoBufferSpace + return &tcpip.ErrNoBufferSpace{} } d.newPacketWaker.Assert() return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +// +// Being a batch API, each packet in pkts should have the following +// fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { enqueued := 0 for pkt := pkts.Front(); pkt != nil; { - pkt.EgressRoute = r - pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] nxt := pkt.Next() if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() } - return enqueued, tcpip.ErrNoBufferSpace + return enqueued, &tcpip.ErrNoBufferSpace{} } pkt = nxt enqueued++ diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 6c410c5a6..e1047da50 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -27,5 +27,6 @@ go_test( library = "rawfile", deps = [ "//pkg/tcpip", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index 604868fd8..406b97709 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -17,7 +17,6 @@ package rawfile import ( - "fmt" "syscall" "gvisor.dev/gvisor/pkg/tcpip" @@ -25,48 +24,54 @@ import ( const maxErrno = 134 -var translations [maxErrno]*tcpip.Error - // TranslateErrno translate an errno from the syscall package into a -// *tcpip.Error. +// tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). -func TranslateErrno(e syscall.Errno) *tcpip.Error { - if e > 0 && e < syscall.Errno(len(translations)) { - if err := translations[e]; err != nil { - return err - } - } - return tcpip.ErrInvalidEndpointState -} - -func addTranslation(host syscall.Errno, trans *tcpip.Error) { - if translations[host] != nil { - panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host)) +// *tcpip.ErrInvalidEndpointState (EINVAL). +func TranslateErrno(e syscall.Errno) tcpip.Error { + switch e { + case syscall.EEXIST: + return &tcpip.ErrDuplicateAddress{} + case syscall.ENETUNREACH: + return &tcpip.ErrNoRoute{} + case syscall.EINVAL: + return &tcpip.ErrInvalidEndpointState{} + case syscall.EALREADY: + return &tcpip.ErrAlreadyConnecting{} + case syscall.EISCONN: + return &tcpip.ErrAlreadyConnected{} + case syscall.EADDRINUSE: + return &tcpip.ErrPortInUse{} + case syscall.EADDRNOTAVAIL: + return &tcpip.ErrBadLocalAddress{} + case syscall.EPIPE: + return &tcpip.ErrClosedForSend{} + case syscall.EWOULDBLOCK: + return &tcpip.ErrWouldBlock{} + case syscall.ECONNREFUSED: + return &tcpip.ErrConnectionRefused{} + case syscall.ETIMEDOUT: + return &tcpip.ErrTimeout{} + case syscall.EINPROGRESS: + return &tcpip.ErrConnectStarted{} + case syscall.EDESTADDRREQ: + return &tcpip.ErrDestinationRequired{} + case syscall.ENOTSUP: + return &tcpip.ErrNotSupported{} + case syscall.ENOTTY: + return &tcpip.ErrQueueSizeNotSupported{} + case syscall.ENOTCONN: + return &tcpip.ErrNotConnected{} + case syscall.ECONNRESET: + return &tcpip.ErrConnectionReset{} + case syscall.ECONNABORTED: + return &tcpip.ErrConnectionAborted{} + case syscall.EMSGSIZE: + return &tcpip.ErrMessageTooLong{} + case syscall.ENOBUFS: + return &tcpip.ErrNoBufferSpace{} + default: + return &tcpip.ErrInvalidEndpointState{} } - translations[host] = trans -} - -func init() { - addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress) - addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute) - addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState) - addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting) - addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected) - addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse) - addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress) - addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend) - addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock) - addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused) - addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout) - addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted) - addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired) - addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported) - addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported) - addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected) - addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset) - addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted) - addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong) - addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace) } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go index e4cdc66bd..61aea1744 100644 --- a/pkg/tcpip/link/rawfile/errors_test.go +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -20,34 +20,35 @@ import ( "syscall" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) func TestTranslateErrno(t *testing.T) { for _, test := range []struct { errno syscall.Errno - translated *tcpip.Error + translated tcpip.Error }{ { errno: syscall.Errno(0), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(maxErrno), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(514), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.EEXIST, - translated: tcpip.ErrDuplicateAddress, + translated: &tcpip.ErrDuplicateAddress{}, }, } { got := TranslateErrno(test.errno) - if got != test.translated { - t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + if diff := cmp.Diff(test.translated, got); diff != "" { + t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff) } } } diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index f4c32c2da..06f3ee21e 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -52,7 +52,7 @@ func GetMTU(name string) (uint32, error) { // NonBlockingWrite writes the given buffer to a file descriptor. It fails if // partial data is written. -func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { +func NonBlockingWrite(fd int, buf []byte) tcpip.Error { var ptr unsafe.Pointer if len(buf) > 0 { ptr = unsafe.Pointer(&buf[0]) @@ -68,7 +68,7 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { // NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. // It fails if partial data is written. -func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) tcpip.Error { iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { @@ -78,7 +78,7 @@ func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { } // NonBlockingSendMMsg sends multiple messages on a socket. -func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e != 0 { return 0, TranslateErrno(e) @@ -97,7 +97,7 @@ type PollEvent struct { // BlockingRead reads from a file descriptor that is set up as non-blocking. If // no data is available, it will block in a poll() syscall until the file // descriptor becomes readable. -func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { +func BlockingRead(fd int, b []byte) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) if e == 0 { @@ -119,7 +119,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { // BlockingReadv reads from a file descriptor that is set up as non-blocking and // stores the data in a list of iovecs buffers. If no data is available, it will // block in a poll() syscall until the file descriptor becomes readable. -func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { +func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) if e == 0 { @@ -149,7 +149,7 @@ type MMsgHdr struct { // and stores the received messages in a slice of MMsgHdr structures. If no data // is available, it will block in a poll() syscall until the file descriptor // becomes readable. -func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e == 0 { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6c937c858..2599bc406 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) views := pkt.Views() @@ -213,14 +213,14 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N e.mu.Unlock() if !ok { - return tcpip.ErrWouldBlock + return &tcpip.ErrWouldBlock{} } return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 23242b9e0..d480ad656 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -425,8 +425,9 @@ func TestFillTxQueue(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -493,8 +494,9 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -538,8 +540,8 @@ func TestFillTxMemory(t *testing.T) { Data: buf.ToVectorisedView(), }) err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if want := tcpip.ErrWouldBlock; err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -579,8 +581,9 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buffer.NewView(bufferSize).ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 5859851d8..bd2b8d4bf 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -187,7 +187,7 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -195,7 +195,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.dumpPacket(directionSend, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index bfac358f4..3829ca9c9 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -149,10 +149,10 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ Name: endpoint.name, }) - switch err { + switch err.(type) { case nil: return endpoint, nil - case tcpip.ErrDuplicateNICID: + case *tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 30f1ad540..20259b285 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -108,7 +108,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if !e.writeGate.Enter() { return nil } @@ -121,7 +121,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if !e.writeGate.Enter() { return pkts.Len(), nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index b139de7dd..e368a9eaa 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { e.writeCount += pkts.Len() return pkts.Len(), nil } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9ebf31b78..0caa65251 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -25,5 +25,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 8a6bcfc2c..c7ab876bf 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -4,9 +4,13 @@ package(licenses = ["notice"]) go_library( name = "arp", - srcs = ["arp.go"], + srcs = [ + "arp.go", + "stats.go", + ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -33,3 +37,15 @@ go_test( "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) + +go_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.go"], + library = ":arp", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 3259d052f..7838cc753 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -19,8 +19,10 @@ package arp import ( "fmt" + "reflect" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,11 +52,12 @@ type endpoint struct { nic stack.NetworkInterface linkAddrCache stack.LinkAddressCache nud stack.NUDHandler + stats sharedStats } -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } e.setEnabled(true) @@ -98,10 +101,12 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (*endpoint) Close() {} +func (e *endpoint) Close() { + e.protocol.forgetEndpoint(e.nic.ID()) +} -func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -110,51 +115,45 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { + return 0, &tcpip.ErrNotSupported{} } -func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats().ARP - stats.PacketsReceived.Increment() + stats := e.stats.arp + stats.packetsReceived.Increment() if !e.isEnabled() { - stats.DisabledPacketsReceived.Increment() + stats.disabledPacketsReceived.Increment() return } h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { - stats.MalformedPacketsReceived.Increment() + stats.malformedPacketsReceived.Increment() return } switch h.Op() { case header.ARPRequest: - stats.RequestsReceived.Increment() + stats.requestsReceived.Increment() localAddr := tcpip.Address(h.ProtocolAddressTarget()) + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + stats.requestsReceivedUnknownTargetAddress.Increment() + return // we have no useful answer, ignore the request + } + + remoteAddr := tcpip.Address(h.ProtocolAddressSender()) + remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + if e.nud == nil { - if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { - stats.RequestsReceivedUnknownTargetAddress.Increment() - return // we have no useful answer, ignore the request - } - - addr := tcpip.Address(h.ProtocolAddressSender()) - linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) } else { - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { - stats.RequestsReceivedUnknownTargetAddress.Increment() - return // we have no useful answer, ignore the request - } - - remoteAddr := tcpip.Address(h.ProtocolAddressSender()) - remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } @@ -186,18 +185,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Send the packet to the (new) target hardware address on the same // hardware on which the request was received. if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil { - stats.OutgoingRepliesDropped.Increment() + stats.outgoingRepliesDropped.Increment() } else { - stats.OutgoingRepliesSent.Increment() + stats.outgoingRepliesSent.Increment() } case header.ARPReply: - stats.RepliesReceived.Increment() + stats.repliesReceived.Increment() addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + e.linkAddrCache.AddLinkAddress(addr, linkAddr) return } @@ -216,9 +215,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + +var _ stack.NetworkProtocol = (*protocol)(nil) +var _ stack.LinkAddressResolver = (*protocol)(nil) + // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack + + mu struct { + sync.RWMutex + + // eps is keyed by NICID to allow protocol methods to retrieve the correct + // endpoint depending on the NIC. + eps map[tcpip.NICID]*endpoint + } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -236,39 +251,62 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L linkAddrCache: linkAddrCache, nud: nud, } + + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + + stackStats := p.stack.Stats() + e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) + + p.mu.Lock() + p.mu.eps[nic.ID()] = e + p.mu.Unlock() + return e } +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, nicID) +} + // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { - stats := p.stack.Stats().ARP +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { + nicID := nic.ID() + + p.mu.Lock() + netEP, ok := p.mu.eps[nicID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + + stats := netEP.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } - nicID := nic.ID() if len(localAddr) == 0 { - addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) - if err != nil { - stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment() - return err + addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if !ok { + return &tcpip.ErrUnknownNICID{} } if len(addr.Address) == 0 { - stats.OutgoingRequestNetworkUnreachableErrors.Increment() - return tcpip.ErrNetworkUnreachable + stats.outgoingRequestInterfaceHasNoLocalAddressErrors.Increment() + return &tcpip.ErrNetworkUnreachable{} } localAddr = addr.Address } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { - stats.OutgoingRequestBadLocalAddressErrors.Increment() - return tcpip.ErrBadLocalAddress + stats.outgoingRequestBadLocalAddressErrors.Increment() + return &tcpip.ErrBadLocalAddress{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -288,10 +326,10 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { - stats.OutgoingRequestsDropped.Increment() + stats.outgoingRequestsDropped.Increment() return err } - stats.OutgoingRequestsSent.Increment() + stats.outgoingRequestsSent.Increment() return nil } @@ -307,13 +345,13 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } // SetOption implements stack.NetworkProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.NetworkProtocol.Option. -func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. @@ -329,5 +367,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu // NewProtocol returns an ARP network protocol. func NewProtocol(s *stack.Stack) stack.NetworkProtocol { - return &protocol{stack: s} + return &protocol{ + stack: s, + mu: struct { + sync.RWMutex + eps map[tcpip.NICID]*endpoint + }{eps: make(map[tcpip.NICID]*endpoint)}, + } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 0536e1698..b0f07aa44 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -125,8 +125,8 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-got +want):\n%s", diff) + if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) } case <-ctx.Done(): return fmt.Errorf("%s for %s", ctx.Err(), want) @@ -537,7 +537,7 @@ type testInterface struct { nicID tcpip.NICID - writeErr *tcpip.Error + writeErr tcpip.Error } func (t *testInterface) ID() tcpip.NICID { @@ -560,15 +560,15 @@ func (*testInterface) Promiscuous() bool { return false } -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) } -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } @@ -585,104 +585,122 @@ func TestLinkAddressRequest(t *testing.T) { testAddr := tcpip.Address([]byte{1, 2, 3, 4}) tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - - linkErr *tcpip.Error - expectedErr *tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - expectedRequestsSent uint64 - expectedRequestBadLocalAddressErrors uint64 - expectedRequestNetworkUnreachableErrors uint64 - expectedRequestDroppedErrors uint64 + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + linkErr tcpip.Error + expectedErr tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress + expectedRequestsSent uint64 + expectedRequestBadLocalAddressErrors uint64 + expectedRequestInterfaceHasNoLocalAddressErrors uint64 + expectedRequestDroppedErrors uint64 }{ { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast with unspecified source", + nicAddr: stackAddr, + localAddr: "", + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast with unspecified source", + nicAddr: stackAddr, + localAddr: "", + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrBadLocalAddress, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast with unassigned address", + nicAddr: stackAddr, + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: &tcpip.ErrBadLocalAddress{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: tcpip.ErrBadLocalAddress, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast with unassigned address", + nicAddr: stackAddr, + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: &tcpip.ErrBadLocalAddress{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with no local address available", - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrNetworkUnreachable, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 1, + name: "Unicast with no local address available", + nicAddr: "", + localAddr: "", + remoteLinkAddr: remoteLinkAddr, + expectedErr: &tcpip.ErrNetworkUnreachable{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 1, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 1, + name: "Multicast with no local address available", + nicAddr: "", + localAddr: "", + remoteLinkAddr: "", + expectedErr: &tcpip.ErrNetworkUnreachable{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 1, + expectedRequestDroppedErrors: 0, }, { - name: "Link error", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - linkErr: tcpip.ErrInvalidEndpointState, - expectedErr: tcpip.ErrInvalidEndpointState, - expectedRequestDroppedErrors: 1, + name: "Link error", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + linkErr: &tcpip.ErrInvalidEndpointState{}, + expectedErr: &tcpip.ErrInvalidEndpointState{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 1, }, } @@ -714,19 +732,20 @@ func TestLinkAddressRequest(t *testing.T) { // link endpoint even though the stack uses the real NIC to validate the // local address. iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} - if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) } if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) } + if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors) + } if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) } - if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors) - } if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) } @@ -774,11 +793,8 @@ func TestLinkAddressRequestWithoutNIC(t *testing.T) { t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") } - if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID) - } - - if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 { - t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got) + err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}) + if _, ok := err.(*tcpip.ErrNotConnected); !ok { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{}) } } diff --git a/pkg/tcpip/network/arp/stats.go b/pkg/tcpip/network/arp/stats.go new file mode 100644 index 000000000..6d7194c6c --- /dev/null +++ b/pkg/tcpip/network/arp/stats.go @@ -0,0 +1,70 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to ARP. +type Stats struct { + // ARP holds ARP statistics. + ARP tcpip.ARPStats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + +type sharedStats struct { + localStats Stats + arp multiCounterARPStats +} + +// LINT.IfChange(multiCounterARPStats) + +type multiCounterARPStats struct { + packetsReceived tcpip.MultiCounterStat + disabledPacketsReceived tcpip.MultiCounterStat + malformedPacketsReceived tcpip.MultiCounterStat + requestsReceived tcpip.MultiCounterStat + requestsReceivedUnknownTargetAddress tcpip.MultiCounterStat + outgoingRequestInterfaceHasNoLocalAddressErrors tcpip.MultiCounterStat + outgoingRequestBadLocalAddressErrors tcpip.MultiCounterStat + outgoingRequestsDropped tcpip.MultiCounterStat + outgoingRequestsSent tcpip.MultiCounterStat + repliesReceived tcpip.MultiCounterStat + outgoingRepliesDropped tcpip.MultiCounterStat + outgoingRepliesSent tcpip.MultiCounterStat +} + +func (m *multiCounterARPStats) init(a, b *tcpip.ARPStats) { + m.packetsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.disabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) + m.malformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) + m.requestsReceived.Init(a.RequestsReceived, b.RequestsReceived) + m.requestsReceivedUnknownTargetAddress.Init(a.RequestsReceivedUnknownTargetAddress, b.RequestsReceivedUnknownTargetAddress) + m.outgoingRequestInterfaceHasNoLocalAddressErrors.Init(a.OutgoingRequestInterfaceHasNoLocalAddressErrors, b.OutgoingRequestInterfaceHasNoLocalAddressErrors) + m.outgoingRequestBadLocalAddressErrors.Init(a.OutgoingRequestBadLocalAddressErrors, b.OutgoingRequestBadLocalAddressErrors) + m.outgoingRequestsDropped.Init(a.OutgoingRequestsDropped, b.OutgoingRequestsDropped) + m.outgoingRequestsSent.Init(a.OutgoingRequestsSent, b.OutgoingRequestsSent) + m.repliesReceived.Init(a.RepliesReceived, b.RepliesReceived) + m.outgoingRepliesDropped.Init(a.OutgoingRepliesDropped, b.OutgoingRepliesDropped) + m.outgoingRepliesSent.Init(a.OutgoingRepliesSent, b.OutgoingRepliesSent) +} + +// LINT.ThenChange(../../tcpip.go:ARPStats) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go new file mode 100644 index 000000000..036fdf739 --- /dev/null +++ b/pkg/tcpip/network/arp/stats_test.go @@ -0,0 +1,93 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp + +import ( + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.NetworkInterface + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func knownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + nic := testInterface{nicID: 1} + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) + } + + ep.Close() + + proto.mu.Lock() + _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEndpointAfterClose { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } +} + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // expected to be bound by a MultiCounterStat. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD index ca1247c1e..411bca25d 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/ip/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "ip", - srcs = ["generic_multicast_protocol.go"], + srcs = [ + "generic_multicast_protocol.go", + "stats.go", + ], visibility = ["//visibility:public"], deps = [ "//pkg/sync", diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index f2f0e069c..a81f5c8c3 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -174,10 +174,10 @@ type MulticastGroupProtocol interface { // // Returns false if the caller should queue the report to be sent later. Note, // returning false does not mean that the receiver hit an error. - SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) + SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error) // SendLeave sends a multicast leave for the specified group address. - SendLeave(groupAddress tcpip.Address) *tcpip.Error + SendLeave(groupAddress tcpip.Address) tcpip.Error } // GenericMulticastProtocolState is the per interface generic multicast protocol diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 85593f211..d5d5a449e 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -141,7 +141,7 @@ func (m *mockMulticastGroupProtocol) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) @@ -158,7 +158,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/ip/stats.go new file mode 100644 index 000000000..898f8b356 --- /dev/null +++ b/pkg/tcpip/network/ip/stats.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import "gvisor.dev/gvisor/pkg/tcpip" + +// LINT.IfChange(MultiCounterIPStats) + +// MultiCounterIPStats holds IP statistics, each counter may have several +// versions. +type MultiCounterIPStats struct { + // PacketsReceived is the total number of IP packets received from the link + // layer. + PacketsReceived tcpip.MultiCounterStat + + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. + DisabledPacketsReceived tcpip.MultiCounterStat + + // InvalidDestinationAddressesReceived is the total number of IP packets + // received with an unknown or invalid destination address. + InvalidDestinationAddressesReceived tcpip.MultiCounterStat + + // InvalidSourceAddressesReceived is the total number of IP packets received + // with a source address that should never have been received on the wire. + InvalidSourceAddressesReceived tcpip.MultiCounterStat + + // PacketsDelivered is the total number of incoming IP packets that are + // successfully delivered to the transport layer. + PacketsDelivered tcpip.MultiCounterStat + + // PacketsSent is the total number of IP packets sent via WritePacket. + PacketsSent tcpip.MultiCounterStat + + // OutgoingPacketErrors is the total number of IP packets which failed to + // write to a link-layer endpoint. + OutgoingPacketErrors tcpip.MultiCounterStat + + // MalformedPacketsReceived is the total number of IP Packets that were + // dropped due to the IP packet header failing validation checks. + MalformedPacketsReceived tcpip.MultiCounterStat + + // MalformedFragmentsReceived is the total number of IP Fragments that were + // dropped due to the fragment failing validation checks. + MalformedFragmentsReceived tcpip.MultiCounterStat + + // IPTablesPreroutingDropped is the total number of IP packets dropped in the + // Prerouting chain. + IPTablesPreroutingDropped tcpip.MultiCounterStat + + // IPTablesInputDropped is the total number of IP packets dropped in the Input + // chain. + IPTablesInputDropped tcpip.MultiCounterStat + + // IPTablesOutputDropped is the total number of IP packets dropped in the + // Output chain. + IPTablesOutputDropped tcpip.MultiCounterStat + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived tcpip.MultiCounterStat + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived tcpip.MultiCounterStat + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived tcpip.MultiCounterStat +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { + m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) + m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) + m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) + m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered) + m.PacketsSent.Init(a.PacketsSent, b.PacketsSent) + m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors) + m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) + m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) + m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) + m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) + m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived) + m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived) + m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) +} + +// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 3005973d7..47cce79bb 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -18,6 +18,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -167,7 +168,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -189,7 +190,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } @@ -203,7 +204,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net panic("not implemented") } -func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { +func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -219,7 +220,7 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } -func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { +func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -235,14 +236,14 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) { +func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) { t.Helper() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) - e := channel.New(0, 1280, "") + e := channel.New(1, mtu, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -263,7 +264,7 @@ func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpo func buildDummyStack(t *testing.T) *stack.Stack { t.Helper() - s, _ := buildDummyStackWithLinkEndpoint(t) + s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) return s } @@ -306,8 +307,8 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func TestSourceAddressValidation(t *testing.T) { @@ -416,7 +417,7 @@ func TestSourceAddressValidation(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t) + s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) test.rxICMP(e, test.srcAddress) var wantValid uint64 @@ -479,8 +480,9 @@ func TestEnableWhenNICDisabled(t *testing.T) { // Attempting to enable the endpoint while the NIC is disabled should // fail. nic.setEnabled(false) - if err := ep.Enable(); err != tcpip.ErrNotPermitted { - t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted) + err := ep.Enable() + if _, ok := err.(*tcpip.ErrNotPermitted); !ok { + t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) } // ep should consider the NIC's enabled status when determining its own // enabled status so we "enable" the NIC to read just the endpoint's @@ -1122,7 +1124,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { remoteAddr tcpip.Address pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) - expectedErr *tcpip.Error + expectedErr tcpip.Error }{ { name: "IPv4", @@ -1187,7 +1189,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { ip.SetHeaderLength(header.IPv4MinimumSize - 1) return hdr.View().ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, { name: "IPv4 too small", @@ -1205,7 +1207,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, { name: "IPv4 minimum size", @@ -1465,7 +1467,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, } @@ -1490,7 +1492,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, }) - e := channel.New(1, 1280, "") + e := channel.New(1, header.IPv6MinimumMTU, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -1506,10 +1508,13 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } defer r.Release() - if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr), - })); err != test.expectedErr { - t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + { + err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr), + })) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) + } } if test.expectedErr != nil { @@ -1526,3 +1531,246 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) } } + +// Test that the included data in an ICMP error packet conforms to the +// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 +func TestICMPInclusionSize(t *testing.T) { + const ( + replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize + replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize + targetSize4 = header.IPv4MinimumProcessableDatagramSize + targetSize6 = header.IPv6MinimumMTU + // A protocol number that will cause an error response. + reservedProtocol = 254 + ) + + // IPv4 function to create a IP packet and send it to the stack. + // The packet should generate an error response. We can do that by using an + // unknown transport protocol (254). + rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { + totalLen := header.IPv4MinimumSize + len(payload) + hdr := buffer.NewPrependable(header.IPv4MinimumSize) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: reservedProtocol, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + vv.AppendView(buffer.View(payload)) + // Take a copy before InjectInbound takes ownership of vv + // as vv may be changed during the call. + v := vv.ToView() + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + return v + } + + // IPv6 function to create a packet and send it to the stack. + // The packet should be errant in a way that causes the stack to send an + // ICMP error response and have enough data to allow the testing of the + // inclusion of the errant packet. Use `unknown next header' to generate + // the error. + rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload)), + TransportProtocol: reservedProtocol, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, + }) + vv := hdr.View().ToVectorisedView() + vv.AppendView(buffer.View(payload)) + // Take a copy before InjectInbound takes ownership of vv + // as vv may be changed during the call. + v := vv.ToView() + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + return v + } + + v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { + // We already know the entire packet is the right size so we can use its + // length to calculate the right payload size to check. + expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize + checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()), + checker.SrcAddr(localIPv4Addr), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), + checker.ICMPv4Payload(payload[:expectedPayloadLength]), + ), + ) + } + + v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { + // We already know the entire packet is the right size so we can use its + // length to calculate the right payload size to check. + expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize + checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()), + checker.SrcAddr(localIPv6Addr), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6ParamProblem), + checker.ICMPv6Code(header.ICMPv6UnknownHeader), + checker.ICMPv6Payload(payload[:expectedPayloadLength]), + ), + ) + } + tests := []struct { + name string + srcAddress tcpip.Address + injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View + checker func(*testing.T, *stack.PacketBuffer, buffer.View) + payloadLength int // Not including IP header. + linkMTU uint32 // Largest IP packet that the link can send as payload. + replyLength int // Total size of IP/ICMP packet expected back. + }{ + { + name: "IPv4 exact match", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 - replyHeaderLength4, + linkMTU: targetSize4, + replyLength: targetSize4, + }, + { + name: "IPv4 larger MTU", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4, + linkMTU: targetSize4 + 1000, + replyLength: targetSize4, + }, + { + name: "IPv4 smaller MTU", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4, + linkMTU: targetSize4 - 50, + replyLength: targetSize4 - 50, + }, + { + name: "IPv4 payload exceeds", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 + 10, + linkMTU: targetSize4, + replyLength: targetSize4, + }, + { + name: "IPv4 1 byte less", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 - replyHeaderLength4 - 1, + linkMTU: targetSize4, + replyLength: targetSize4 - 1, + }, + { + name: "IPv4 No payload", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: 0, + linkMTU: targetSize4, + replyLength: replyHeaderLength4, + }, + { + name: "IPv6 exact match", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6 - replyHeaderLength6, + linkMTU: targetSize6, + replyLength: targetSize6, + }, + { + name: "IPv6 larger MTU", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6, + linkMTU: targetSize6 + 400, + replyLength: targetSize6, + }, + // NB. No "smaller MTU" test here as less than 1280 is not permitted + // in IPv6. + { + name: "IPv6 payload exceeds", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6, + linkMTU: targetSize6, + replyLength: targetSize6, + }, + { + name: "IPv6 1 byte less", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6 - replyHeaderLength6 - 1, + linkMTU: targetSize6, + replyLength: targetSize6 - 1, + }, + { + name: "IPv6 no payload", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: 0, + linkMTU: targetSize6, + replyLength: replyHeaderLength6, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU) + // Allocate and initialize the payload view. + payload := buffer.NewView(test.payloadLength) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + // Default routes for IPv4&6 so ICMP can find a route to the remote + // node when attempting to send the ICMP error Reply. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + v := test.injector(e, test.srcAddress, payload) + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + if got, want := pkt.Pkt.Size(), test.replyLength; got != want { + t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) + } + test.checker(t, pkt.Pkt, v) + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 32f53f217..330a7d170 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -8,6 +8,7 @@ go_library( "icmp.go", "igmp.go", "ipv4.go", + "stats.go", ], visibility = ["//visibility:public"], deps = [ @@ -49,3 +50,15 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.go"], + library = ":ipv4", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 8e392f86c..3d93a2cd0 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ package ipv4 import ( - "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -62,21 +61,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - received := stats.ICMP.V4.PacketsReceived + received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { - received.Invalid.Increment() + received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because // we are the only handler, however other types do not cope well with @@ -106,19 +104,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := e.processIPOptions(pkt, opts, op) - if err != nil { - switch { - case - errors.Is(err, header.ErrIPv4OptDuplicate), - errors.Is(err, errIPv4RecordRouteOptInvalidLength), - errors.Is(err, errIPv4RecordRouteOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + tmp, optProblem := e.processIPOptions(pkt, opts, op) + if optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + }, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + e.stats.ip.MalformedPacketsReceived.Increment() } return } @@ -128,11 +121,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: - received.Echo.Increment() + received.echo.Increment() - sent := stats.ICMP.V4.PacketsSent + sent := e.stats.icmp.packetsSent if !e.protocol.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return } @@ -213,18 +206,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.EchoReply.Increment() + sent.echoReply.Increment() case header.ICMPv4EchoReply: - received.EchoReply.Increment() + received.echoReply.Increment() e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: - received.DstUnreachable.Increment() + received.dstUnreachable.Increment() pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { @@ -243,31 +236,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } case header.ICMPv4SrcQuench: - received.SrcQuench.Increment() + received.srcQuench.Increment() case header.ICMPv4Redirect: - received.Redirect.Increment() + received.redirect.Increment() case header.ICMPv4TimeExceeded: - received.TimeExceeded.Increment() + received.timeExceeded.Increment() case header.ICMPv4ParamProblem: - received.ParamProblem.Increment() + received.paramProblem.Increment() case header.ICMPv4Timestamp: - received.Timestamp.Increment() + received.timestamp.Increment() case header.ICMPv4TimestampReply: - received.TimestampReply.Increment() + received.timestampReply.Increment() case header.ICMPv4InfoRequest: - received.InfoRequest.Increment() + received.infoRequest.Increment() case header.ICMPv4InfoReply: - received.InfoReply.Increment() + received.infoReply.Increment() default: - received.Invalid.Increment() + received.invalid.Increment() } } @@ -317,7 +310,7 @@ func (*icmpReasonParamProblem) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { origIPHdr := header.IPv4(pkt.NetworkHeader().View()) origIPHdrSrc := origIPHdr.SourceAddress() origIPHdrDst := origIPHdr.DestinationAddress() @@ -379,9 +372,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - sent := p.stack.Stats().ICMP.V4.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[pkt.NICID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + + sent := netEP.stats.icmp.packetsSent + if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return nil } @@ -435,13 +436,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. mtu := int(route.MTU()) - if mtu > header.IPv4MinimumProcessableDatagramSize { - mtu = header.IPv4MinimumProcessableDatagramSize + const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize + if mtu > maxIPData { + mtu = maxIPData } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize - available := int(mtu) - headerLen + available := mtu - header.ICMPv4MinimumSize - if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { + if available < len(origIPHdr)+header.ICMPv4MinimumErrorPayloadSize { return nil } @@ -464,36 +465,36 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi payload.CapLength(payloadLen) icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, Data: payload, }) icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter *tcpip.StatCounter + var counter tcpip.MultiCounterStat switch reason := reason.(type) { case *icmpReasonPortUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonProtoUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonParamProblem: icmpHdr.SetType(header.ICMPv4ParamProblem) icmpHdr.SetCode(header.ICMPv4UnusedCode) icmpHdr.SetPointer(reason.pointer) - counter = sent.ParamProblem + counter = sent.paramProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -508,7 +509,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi }, icmpPkt, ); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } counter.Increment() diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index d9b5fe6ed..4cd0b3256 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -103,7 +103,7 @@ func (igmp *igmpState) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { igmpType := header.IGMPv2MembershipReport if igmp.v1Present() { igmpType = header.IGMPv1MembershipReport @@ -114,7 +114,7 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Erro // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error { // As per RFC 2236 Section 6, Page 8: "If the interface state says the // Querier is running IGMPv1, this action SHOULD be skipped. If the flag // saying we were the last host to report is cleared, this action MAY be @@ -149,51 +149,49 @@ func (igmp *igmpState) init(ep *endpoint) { // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { - stats := igmp.ep.protocol.stack.Stats() - received := stats.IGMP.PacketsReceived + received := igmp.ep.stats.igmp.packetsReceived headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.IGMP(headerView) - // Temporarily reset the checksum field to 0 in order to calculate the proper - // checksum. - wantChecksum := h.Checksum() - h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(wantChecksum) - - if gotChecksum != wantChecksum { - received.ChecksumErrors.Increment() + // As per RFC 1071 section 1.3, + // + // To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF { + received.checksumErrors.Increment() return } switch h.Type() { case header.IGMPMembershipQuery: - received.MembershipQuery.Increment() + received.membershipQuery.Increment() if len(headerView) < header.IGMPQueryMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) case header.IGMPv1MembershipReport: - received.V1MembershipReport.Increment() + received.v1MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPv2MembershipReport: - received.V2MembershipReport.Increment() + received.v2MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPLeaveGroup: - received.LeaveGroup.Increment() + received.leaveGroup.Increment() // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or // Report, are ignored in all states" @@ -201,7 +199,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should // be silently ignored. New message types may be used by newer versions of // IGMP, by multicast routing protocols, or other uses." - received.Unrecognized.Increment() + received.unrecognized.Increment() } } @@ -244,7 +242,7 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { // writePacket assembles and sends an IGMP packet. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, tcpip.Error) { igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) igmpData.SetType(igmpType) igmpData.SetGroupAddress(groupAddress) @@ -272,18 +270,18 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip panic(fmt.Sprintf("failed to add IP header: %s", err)) } - sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.stats.igmp.packetsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sentStats.Dropped.Increment() + sentStats.dropped.Increment() return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sentStats.V1MembershipReport.Increment() + sentStats.v1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sentStats.V2MembershipReport.Increment() + sentStats.v2MembershipReport.Increment() case header.IGMPLeaveGroup: - sentStats.LeaveGroup.Increment() + sentStats.leaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } @@ -295,7 +293,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // messages. // // If the group already exists in the membership map, returns -// tcpip.ErrDuplicateAddress. +// *tcpip.ErrDuplicateAddress. // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { @@ -314,13 +312,13 @@ func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { // if required. // // Precondition: igmp.ep.mu must be locked. -func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) tcpip.Error { // LeaveGroup returns false only if the group was not joined. if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } // softLeaveAll leaves all groups from the perspective of IGMP, but remains diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bb25a76fe..e5c80699d 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ package ipv4 import ( - "errors" "fmt" "math" + "reflect" "sync/atomic" "time" @@ -73,6 +73,7 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol + stats sharedStats // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -114,18 +115,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa e.mu.addressableEndpointState.Init(e) e.mu.igmp.init(e) e.mu.Unlock() + + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + + stackStats := p.stack.Stats() + e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP) + e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V4) + e.stats.igmp.init(&e.stats.localStats.IGMP, &stackStats.IGMP) + + p.mu.Lock() + p.mu.eps[nic.ID()] = e + p.mu.Unlock() + return e } +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, nicID) +} + // Enable implements stack.NetworkEndpoint. -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If the NIC is not enabled, the endpoint can't do anything meaningful so // don't enable the endpoint. if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } // If the endpoint is already enabled, there is nothing for it to do. @@ -193,7 +212,9 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } @@ -202,7 +223,9 @@ func (e *endpoint) disableLocked() { e.mu.igmp.softLeaveAll() // The address may have already been removed. - if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } @@ -237,7 +260,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) tcpip.Error { hdrLen := header.IPv4MinimumSize var optLen int if options != nil { @@ -245,12 +268,12 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet } hdrLen += optLen if hdrLen > header.IPv4MaximumHeaderSize { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := pkt.Size() if length > math.MaxUint16 { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams @@ -275,7 +298,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -295,17 +318,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { return err } // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. - e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() + e.stats.ip.IPTablesOutputDropped.Increment() return nil } @@ -334,7 +357,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return e.writePacket(r, gso, pkt, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { @@ -349,35 +372,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + stats := e.stats.ip + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) - r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + stats.PacketsSent.IncrementBy(uint64(sent)) + stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } - r.Stats().IP.PacketsSent.Increment() + stats.PacketsSent.Increment() return nil } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } @@ -385,6 +410,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + stats := e.stats.ip + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { return 0, err @@ -392,7 +419,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } @@ -400,7 +427,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt @@ -413,21 +440,21 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } return n, err } - r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. @@ -451,36 +478,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) + stats.PacketsSent.IncrementBy(uint64(n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err } n++ } - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) // Dropped packets aren't errors, so include them in the return value. return n + len(dropped), nil } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } hdrLen := header.IPv4(h).HeaderLength() if hdrLen < header.IPv4MinimumSize { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } h, ok = pkt.Data.PullUp(int(hdrLen)) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } ip := header.IPv4(h) @@ -518,14 +545,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // wire format. We also want to check if the header's fields are valid before // sending the packet. if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv4(pkt.NetworkHeader().View()) ttl := h.TTL() if ttl == 0 { @@ -545,7 +572,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { networkEndpoint.(*endpoint).handlePacket(pkt) return nil } - if err != tcpip.ErrBadAddress { + if _, ok := err.(*tcpip.ErrBadAddress); !ok { return err } @@ -577,19 +604,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - stats.IP.PacketsReceived.Increment() + stats := e.stats.ip + + stats.PacketsReceived.Increment() if !e.isEnabled() { - stats.IP.DisabledPacketsReceived.Increment() + stats.DisabledPacketsReceived.Increment() return } // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesPreroutingDropped.Increment() + stats.IPTablesPreroutingDropped.Increment() return } } @@ -601,11 +630,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() - stats := e.protocol.stack.Stats() + stats := e.stats h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -631,7 +660,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -643,7 +672,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // be one of its own IP addresses (but not a broadcast or // multicast address). if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } // Make sure the source address is not a subnet-local broadcast address. @@ -651,7 +680,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { subnet := addressEndpoint.Subnet() addressEndpoint.DecRef() if subnet.IsBroadcast(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } } @@ -664,7 +693,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() + stats.ip.InvalidDestinationAddressesReceived.Increment() return } @@ -674,9 +703,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesInputDropped.Increment() + stats.ip.IPTablesInputDropped.Increment() return } @@ -684,8 +714,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -698,8 +728,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } @@ -720,8 +750,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt, ) if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } if !ready { @@ -734,7 +764,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } - stats.IP.PacketsDelivered.Increment() + stats.ip.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { @@ -755,19 +785,13 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{}) - if err != nil { - switch { - case - errors.Is(err, header.ErrIPv4OptDuplicate), - errors.Is(err, errIPv4RecordRouteOptInvalidPointer), - errors.Is(err, errIPv4RecordRouteOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}); optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + }, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() } return } @@ -800,10 +824,12 @@ func (e *endpoint) Close() { e.disableLocked() e.mu.addressableEndpointState.Cleanup() + + e.protocol.forgetEndpoint(e.nic.ID()) } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() @@ -815,7 +841,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p } // RemovePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.mu.addressableEndpointState.RemovePermanentAddress(addr) @@ -872,7 +898,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) @@ -881,9 +907,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { // joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } e.mu.igmp.joinGroup(addr) @@ -891,7 +917,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) @@ -900,7 +926,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error { return e.mu.igmp.leaveGroup(addr) } @@ -911,6 +937,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { return e.mu.igmp.isInGroup(addr) } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -918,6 +949,14 @@ var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack + mu struct { + sync.RWMutex + + // eps is keyed by NICID to allow protocol methods to retrieve an endpoint + // when handling a packet, by looking at which NIC handled the packet. + eps map[tcpip.NICID]*endpoint + } + // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful. // @@ -960,24 +999,24 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: p.SetDefaultTTL(uint8(*v)) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -1023,9 +1062,9 @@ func (p *protocol) SetForwarding(v bool) { // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. -func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { if linkMTU < header.IPv4MinimumMTU { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in @@ -1033,7 +1072,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Erro // The maximal internet header is 60 octets, and a typical internet header // is 20 octets, allowing a margin for headers of higher level protocols. if networkHeaderSize > header.IPv4MaximumHeaderSize { - return 0, tcpip.ErrMalformedHeader + return 0, &tcpip.ErrMalformedHeader{} } networkMTU := linkMTU @@ -1095,6 +1134,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { options: opts, } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + p.mu.eps = make(map[tcpip.NICID]*endpoint) return p } } @@ -1192,16 +1232,9 @@ func (*optionUsageEcho) actions() optionActions { } } -var ( - errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") - errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") - errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") - errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") -) - // handleTimestamp does any required processing on a Timestamp option // in place. -func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) *header.IPv4OptParameterProblem { flags := tsOpt.Flags() var entrySize uint8 switch flags { @@ -1212,7 +1245,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres header.IPv4OptionTimestampWithPredefinedIPFlag: entrySize = header.IPv4OptionTimestampWithAddrSize default: - return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSOFLWAndFLGOffset, + NeedICMP: true, + } } pointer := tsOpt.Pointer() @@ -1220,7 +1256,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres // Since the pointer is 1 based, and the header is 4 bytes long the // pointer must point beyond the header therefore 4 or less is bad. if pointer <= header.IPv4OptionTimestampHdrLength { - return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSPointerOffset, + NeedICMP: true, + } } // To simplify processing below, base further work on the array of timestamps // beyond the header, rather than on the whole option. Also to aid @@ -1254,14 +1293,17 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres // timestamp, but the overflow count is incremented by one. if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { // By definition we have nothing to do. - return 0, nil + return nil } if tsOpt.IncOverflow() != 0 { - return 0, nil + return nil } // The overflow count is also full. - return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSOFLWAndFLGOffset, + NeedICMP: true, + } } if nextSlot+entrySize > dataLength { // The data area isn't full but there isn't room for a new entry. @@ -1280,32 +1322,36 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres if dataLength%entrySize != 0 { // The Data section size should be a multiple of the expected // timestamp entry size. - return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: false, + } } // If the size is OK, the pointer must be corrupted. } - return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSPointerOffset, + NeedICMP: true, + } } if usage.actions().timestamp == optionProcess { tsOpt.UpdateTimestamp(localAddress, clock) } - return 0, nil + return nil } -var ( - errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") - errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") -) - // handleRecordRoute checks and processes a Record route option. It is much // like the timestamp type 1 option, but without timestamps. The passed in // address is stored in the option in the correct spot if possible. -func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) *header.IPv4OptParameterProblem { optlen := rrOpt.Size() if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { - return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: true, + } } pointer := rrOpt.Pointer() @@ -1315,7 +1361,10 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // Since the pointer is 1 based, and the header is 3 bytes long the // pointer must point beyond the header therefore 3 or less is bad. if pointer <= header.IPv4OptionRecordRouteHdrLength { - return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptRRPointerOffset, + NeedICMP: true, + } } // RFC 791 page 21 says @@ -1332,7 +1381,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // of these words is a copy/paste error from the timestamp option where // there are two failure reasons given. if pointer > optlen { - return 0, nil + return nil } // The data area isn't full but there isn't room for a new entry. @@ -1357,17 +1406,23 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // } if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { // Length is bad, not on integral number of slots. - return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: true, + } } // If not length, the fault must be with the pointer. } - return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptRRPointerOffset, + NeedICMP: true, + } } if usage.actions().recordRoute == optionVerify { - return 0, nil + return nil } rrOpt.StoreAddress(localAddress) - return 0, nil + return nil } // processIPOptions parses the IPv4 options and produces a new set of options @@ -1378,8 +1433,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // - The location of an error if there was one (or 0 if no error) // - If there is an error, information as to what it was was. // - The replacement option set. -func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { - stats := e.protocol.stack.Stats() +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, *header.IPv4OptParameterProblem) { + stats := e.stats.ip opts := header.IPv4Options(orig) optIter := opts.MakeIterator() @@ -1392,21 +1447,23 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt // This will need tweaking when we start really forwarding packets // as we may need to get two addresses, for rx and tx interfaces. // We will also have to take usage into account. - prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) + prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) localAddress := prefixedAddress.Address - if err != nil { + if !ok { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { - return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + return nil, &header.IPv4OptParameterProblem{ + NeedICMP: false, + } } localAddress = dstAddr } for { - option, done, err := optIter.Next() - if done || err != nil { - return optIter.ErrCursor, optIter.Finalize(), err + option, done, optProblem := optIter.Next() + if done || optProblem != nil { + return optIter.Finalize(), optProblem } optType := option.Type() if optType == header.IPv4OptionNOPType { @@ -1415,44 +1472,47 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt } if optType == header.IPv4OptionListEndType { optIter.PushNOPOrEnd(optType) - return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + return optIter.Finalize(), nil } // check for repeating options (multiple NOPs are OK) if seenOptions[optType] { - return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + return nil, &header.IPv4OptParameterProblem{ + Pointer: optIter.ErrCursor, + NeedICMP: true, + } } seenOptions[optType] = true optLen := int(option.Size()) switch option := option.(type) { case *header.IPv4OptionTimestamp: - stats.IP.OptionTSReceived.Increment() + stats.OptionTSReceived.Increment() if usage.actions().timestamp != optionRemove { clock := e.protocol.stack.Clock() newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) - offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) - if err != nil { - return optIter.ErrCursor + offset, nil, err + if optProblem := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage); optProblem != nil { + optProblem.Pointer += optIter.ErrCursor + return nil, optProblem } optIter.ConsumeBuffer(optLen) } case *header.IPv4OptionRecordRoute: - stats.IP.OptionRRReceived.Increment() + stats.OptionRRReceived.Increment() if usage.actions().recordRoute != optionRemove { newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) - offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) - if err != nil { - return optIter.ErrCursor + offset, nil, err + if optProblem := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage); optProblem != nil { + optProblem.Pointer += optIter.ErrCursor + return nil, optProblem } optIter.ConsumeBuffer(optLen) } default: - stats.IP.OptionUnknownReceived.Increment() + stats.OptionUnknownReceived.Increment() if usage.actions().unknown == optionPass { newBuffer := optIter.RemainingBuffer()[:optLen] // Arguments already heavily checked.. ignore result. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index a9e137c24..ed5899f0b 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -78,8 +78,11 @@ func TestExcludeBroadcast(t *testing.T) { defer ep.Close() // Cannot connect using a broadcast address as the source. - if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + { + err := ep.Connect(randomAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) + } } // However, we can bind to a broadcast address to listen. @@ -270,6 +273,11 @@ func TestIPv4Sanity(t *testing.T) { nicID = 1 randomSequence = 123 randomIdent = 42 + // In some cases Linux sets the error pointer to the start of the option + // (offset 0) instead of the actual wrong value, which is the length byte + // (offset 1). For compatibility we must do the same. Use this constant + // to indicate where this happens. + pointerOffsetForInvalidLength = 0 ) var ( ipv4Addr = tcpip.AddressWithPrefix{ @@ -439,6 +447,21 @@ func TestIPv4Sanity(t *testing.T) { replyOptions: header.IPv4Options{}, }, { + name: "bad option - no length", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 1, 1, 1, 68, + // ^-start of timestamp.. but no length.. + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + }, + { name: "bad option - length 0", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), @@ -448,7 +471,27 @@ func TestIPv4Sanity(t *testing.T) { // ^ 1, 2, 3, 4, }, - shouldFail: true, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, + }, + { + name: "bad option - length 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 1, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "bad option - length big", @@ -462,7 +505,11 @@ func TestIPv4Sanity(t *testing.T) { // space is not possible. (Second byte) 1, 2, 3, 4, }, - shouldFail: true, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { // This tests for some linux compatible behaviour. @@ -484,7 +531,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "multiple type 0 with room", @@ -589,7 +636,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "valid timestamp pointer", @@ -624,7 +671,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, // End of option list with illegal option after it, which should be ignored. { @@ -636,24 +683,31 @@ func TestIPv4Sanity(t *testing.T) { 68, 12, 13, 0x11, 192, 168, 1, 12, 1, 2, 3, 4, - 0, 10, 3, 99, + 0, 10, 3, 99, // EOL followed by junk }, replyOptions: header.IPv4Options{ 68, 12, 13, 0x21, 192, 168, 1, 12, 1, 2, 3, 4, - 0, 0, 0, 0, // 3 bytes unknown option - }, // ^ End of options hides following bytes. + 0, // End of Options hides following bytes. + 0, 0, 0, // 3 bytes unknown option removed. + }, }, { - // Timestamp with a size too small. + // Timestamp with a size much too small. name: "timestamp truncated", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, - options: header.IPv4Options{68, 1, 0, 0}, - // ^ Smallest possible is 8. - shouldFail: true, + options: header.IPv4Options{ + 68, 1, 0, 0, + // ^ Smallest possible is 8. Linux points at the 68. + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "single record route with room", @@ -751,7 +805,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header @@ -769,7 +823,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header @@ -808,8 +862,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, - replyOptions: header.IPv4Options{}, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { name: "duplicate record route", @@ -828,7 +881,6 @@ func TestIPv4Sanity(t *testing.T) { ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + 7, - replyOptions: header.IPv4Options{}, }, } @@ -884,7 +936,6 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } - ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: test.transportProtocol, @@ -1328,8 +1379,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize int allowPackets int outgoingErrors int - mockError *tcpip.Error - wantError *tcpip.Error + mockError tcpip.Error + wantError tcpip.Error }{ { description: "No frag", @@ -1338,8 +1389,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 1, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag", @@ -1348,8 +1399,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 3, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on second frag", @@ -1358,8 +1409,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 1, outgoingErrors: 2, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag MTU smaller than header", @@ -1368,8 +1419,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize: 500, allowPackets: 0, outgoingErrors: 4, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error when MTU is smaller than IPv4 minimum MTU", @@ -1379,7 +1430,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrInvalidEndpointState, + wantError: &tcpip.ErrInvalidEndpointState{}, }, } @@ -1393,8 +1444,8 @@ func TestFragmentationErrors(t *testing.T) { TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.wantError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + if diff := cmp.Diff(ft.wantError, err); diff != "" { + t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) } if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) @@ -2427,8 +2478,9 @@ func TestReceiveFragments(t *testing.T) { } } - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -2506,11 +2558,11 @@ func TestWriteStats(t *testing.T) { // Parameterize the tests to run with both WritePacket and WritePackets. writers := []struct { name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) }{ { name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { @@ -2522,7 +2574,7 @@ func TestWriteStats(t *testing.T) { }, }, { name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) }, }, @@ -2532,7 +2584,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -2608,7 +2660,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go new file mode 100644 index 000000000..bee72c649 --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats.go @@ -0,0 +1,190 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.IPNetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to the IPv4 protocol family. +type Stats struct { + // IP holds IPv4 statistics. + IP tcpip.IPStats + + // IGMP holds IGMP statistics. + IGMP tcpip.IGMPStats + + // ICMP holds ICMPv4 statistics. + ICMP tcpip.ICMPv4Stats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + +// IPStats implements stack.IPNetworkEndointStats +func (s *Stats) IPStats() *tcpip.IPStats { + return &s.IP +} + +type sharedStats struct { + localStats Stats + ip ip.MultiCounterIPStats + icmp multiCounterICMPv4Stats + igmp multiCounterIGMPStats +} + +// LINT.IfChange(multiCounterICMPv4PacketStats) + +type multiCounterICMPv4PacketStats struct { + echo tcpip.MultiCounterStat + echoReply tcpip.MultiCounterStat + dstUnreachable tcpip.MultiCounterStat + srcQuench tcpip.MultiCounterStat + redirect tcpip.MultiCounterStat + timeExceeded tcpip.MultiCounterStat + paramProblem tcpip.MultiCounterStat + timestamp tcpip.MultiCounterStat + timestampReply tcpip.MultiCounterStat + infoRequest tcpip.MultiCounterStat + infoReply tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) { + m.echo.Init(a.Echo, b.Echo) + m.echoReply.Init(a.EchoReply, b.EchoReply) + m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) + m.srcQuench.Init(a.SrcQuench, b.SrcQuench) + m.redirect.Init(a.Redirect, b.Redirect) + m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded) + m.paramProblem.Init(a.ParamProblem, b.ParamProblem) + m.timestamp.Init(a.Timestamp, b.Timestamp) + m.timestampReply.Init(a.TimestampReply, b.TimestampReply) + m.infoRequest.Init(a.InfoRequest, b.InfoRequest) + m.infoReply.Init(a.InfoReply, b.InfoReply) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats) + +// LINT.IfChange(multiCounterICMPv4SentPacketStats) + +type multiCounterICMPv4SentPacketStats struct { + multiCounterICMPv4PacketStats + dropped tcpip.MultiCounterStat + rateLimited tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.dropped.Init(a.Dropped, b.Dropped) + m.rateLimited.Init(a.RateLimited, b.RateLimited) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats) + +// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats) + +type multiCounterICMPv4ReceivedPacketStats struct { + multiCounterICMPv4PacketStats + invalid tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.invalid.Init(a.Invalid, b.Invalid) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats) + +// LINT.IfChange(multiCounterICMPv4Stats) + +type multiCounterICMPv4Stats struct { + packetsSent multiCounterICMPv4SentPacketStats + packetsReceived multiCounterICMPv4ReceivedPacketStats +} + +func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4Stats) + +// LINT.IfChange(multiCounterIGMPPacketStats) + +type multiCounterIGMPPacketStats struct { + membershipQuery tcpip.MultiCounterStat + v1MembershipReport tcpip.MultiCounterStat + v2MembershipReport tcpip.MultiCounterStat + leaveGroup tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) { + m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery) + m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport) + m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport) + m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup) +} + +// LINT.ThenChange(../../tcpip.go:IGMPPacketStats) + +// LINT.IfChange(multiCounterIGMPSentPacketStats) + +type multiCounterIGMPSentPacketStats struct { + multiCounterIGMPPacketStats + dropped tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.dropped.Init(a.Dropped, b.Dropped) +} + +// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats) + +// LINT.IfChange(multiCounterIGMPReceivedPacketStats) + +type multiCounterIGMPReceivedPacketStats struct { + multiCounterIGMPPacketStats + invalid tcpip.MultiCounterStat + checksumErrors tcpip.MultiCounterStat + unrecognized tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.invalid.Init(a.Invalid, b.Invalid) + m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors) + m.unrecognized.Init(a.Unrecognized, b.Unrecognized) +} + +// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats) + +// LINT.IfChange(multiCounterIGMPStats) + +type multiCounterIGMPStats struct { + packetsSent multiCounterIGMPSentPacketStats + packetsReceived multiCounterIGMPReceivedPacketStats +} + +func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:IGMPStats) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go new file mode 100644 index 000000000..b28e7dcde --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.NetworkInterface + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func knownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + nic := testInterface{nicID: 1} + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) + } + + ep.Close() + + proto.mu.Lock() + _, hasEP := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } +} + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // expected to be bound by a MultiCounterStat. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index afa45aefe..0c5f8d683 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -10,6 +10,7 @@ go_library( "ipv6.go", "mld.go", "ndp.go", + "stats.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6ee162713..7298bd061 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -125,15 +125,14 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := e.protocol.stack.Stats().ICMP - sent := stats.V6.PacketsSent - received := stats.V6.PacketsReceived + sent := e.stats.icmp.packetsSent + received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.ICMPv6(v) @@ -147,7 +146,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -165,10 +164,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: - received.PacketTooBig.Increment() + received.packetTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) @@ -179,10 +178,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: - received.DstUnreachable.Increment() + received.dstUnreachable.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) @@ -194,9 +193,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6NeighborSolicit: - received.NeighborSolicit.Increment() + received.neighborSolicit.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -210,7 +209,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast // address. if header.IsV6MulticastAddress(targetAddr) { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -238,7 +237,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: + default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) } } @@ -263,13 +264,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { if err != nil { // Options are not valid as per the wire format, silently drop the // packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok = getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } } @@ -282,16 +283,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { - received.Invalid.Increment() + received.invalid.Increment() return } } else if unspecifiedSource { - received.Invalid.Increment() + received.invalid.Increment() return } else if e.nud != nil { e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -301,7 +302,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -379,15 +380,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.NeighborAdvert.Increment() + sent.neighborAdvert.Increment() case header.ICMPv6NeighborAdvert: - received.NeighborAdvert.Increment() + received.neighborAdvert.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -414,16 +415,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: + return + default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) } - return } it, err := na.Options().Iter(false /* check */) if err != nil { // If we have a malformed NDP NA option, drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } @@ -438,7 +441,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // DAD. targetLinkAddr, ok := getTargetLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() + return + } + + // As per RFC 4861 section 7.1.2: + // A node MUST silently discard any received Neighbor Advertisement + // messages that do not satisfy all of the following validity checks: + // ... + // - If the IP Destination Address is a multicast address the + // Solicited flag is zero. + if header.IsV6MulticastAddress(dstAddr) && na.SolicitedFlag() { + received.invalid.Increment() return } @@ -446,7 +460,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // address cache with the link address for the target of the message. if e.nud == nil { if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) + e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr) } return } @@ -458,10 +472,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { }) case header.ICMPv6EchoRequest: - received.EchoRequest.Increment() + received.echoRequest.Increment() icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -493,27 +507,27 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, }, replyPkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.EchoReply.Increment() + sent.echoReply.Increment() case header.ICMPv6EchoReply: - received.EchoReply.Increment() + received.echoReply.Increment() if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: - received.TimeExceeded.Increment() + received.timeExceeded.Increment() case header.ICMPv6ParamProblem: - received.ParamProblem.Increment() + received.paramProblem.Increment() case header.ICMPv6RouterSolicit: - received.RouterSolicit.Increment() + received.routerSolicit.Increment() // // Validate the RS as per RFC 4861 section 6.1.1. @@ -521,7 +535,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the NDP payload of sufficient size to hold a Router Solictation? if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -530,7 +544,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { // ... No, silently drop the packet. - received.RouterOnlyPacketsDroppedByHost.Increment() + received.routerOnlyPacketsDroppedByHost.Increment() return } @@ -540,13 +554,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok := getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -557,7 +571,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. if srcAddr == header.IPv6Any { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -569,7 +583,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6RouterAdvert: - received.RouterAdvert.Increment() + received.routerAdvert.Increment() // // Validate the RA as per RFC 4861 section 6.1.2. @@ -577,7 +591,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the NDP payload of sufficient size to hold a Router Advertisement? if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -586,7 +600,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { // ...No, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } @@ -596,13 +610,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok := getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -638,26 +652,26 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // link-layer address be modified due to receiving one of the above // messages, the state SHOULD also be set to STALE to provide prompt // verification that the path to the new link-layer address is working." - received.RedirectMsg.Increment() + received.redirectMsg.Increment() if !isNDPValid() { - received.Invalid.Increment() + received.invalid.Increment() return } case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: switch icmpType { case header.ICMPv6MulticastListenerQuery: - received.MulticastListenerQuery.Increment() + received.multicastListenerQuery.Increment() case header.ICMPv6MulticastListenerReport: - received.MulticastListenerReport.Increment() + received.multicastListenerReport.Increment() case header.ICMPv6MulticastListenerDone: - received.MulticastListenerDone.Increment() + received.multicastListenerDone.Increment() default: panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -676,7 +690,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } default: - received.Unrecognized.Increment() + received.unrecognized.Increment() } } @@ -688,26 +702,39 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { + nicID := nic.ID() + + p.mu.Lock() + netEP, ok := p.mu.eps[nicID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + remoteAddr := targetAddr if len(remoteLinkAddr) == 0 { remoteAddr = header.SolicitedNodeAddr(targetAddr) remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + if len(localAddr) == 0 { + addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) + if addressEndpoint == nil { + return &tcpip.ErrNetworkUnreachable{} + } + + localAddr = addressEndpoint.AddressWithPrefix().Address + } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 { + return &tcpip.ErrBadLocalAddress{} } - defer r.Release() - r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) @@ -715,18 +742,23 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{})) - stat := p.stack.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }, pkt); err != nil { - stat.Dropped.Increment() + }, header.IPv6ExtHdrSerializer{}); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } + + stat := netEP.stats.icmp.packetsSent + + if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + stat.dropped.Increment() return err } - stat.NeighborSolicit.Increment() + stat.neighborSolicit.Increment() return nil } @@ -796,7 +828,7 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { origIPHdr := header.IPv6(pkt.NetworkHeader().View()) origIPHdrSrc := origIPHdr.SourceAddress() origIPHdrDst := origIPHdr.DestinationAddress() @@ -863,10 +895,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - stats := p.stack.Stats().ICMP - sent := stats.V6.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[pkt.NICID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + + sent := netEP.stats.icmp.packetsSent + if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return nil } @@ -897,11 +936,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // the error message packet exceed the minimum IPv6 MTU // [IPv6]. mtu := int(route.MTU()) - if mtu > header.IPv6MinimumMTU { - mtu = header.IPv6MinimumMTU + const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize + if mtu > maxIPv6Data { + mtu = maxIPv6Data } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize - available := int(mtu) - headerLen + available := mtu - header.ICMPv6ErrorHeaderSize if available < header.IPv6MinimumSize { return nil } @@ -915,31 +954,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize, Data: payload, }) newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - var counter *tcpip.StatCounter + var counter tcpip.MultiCounterStat switch reason := reason.(type) { case *icmpReasonParameterProblem: icmpHdr.SetType(header.ICMPv6ParamProblem) icmpHdr.SetCode(reason.code) icmpHdr.SetTypeSpecific(reason.pointer) - counter = sent.ParamProblem + counter = sent.paramProblem case *icmpReasonPortUnreachable: icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) - counter = sent.TimeExceeded + counter = sent.timeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -953,7 +992,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi }, newPkt, ); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } counter.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index b1e6a70a2..db1c2e663 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "bytes" "context" "net" "reflect" @@ -22,6 +23,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -77,7 +79,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { return nil } @@ -91,16 +93,11 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } -type stubLinkAddressCache struct { - stack.LinkAddressCache -} +var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil) -func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID { - return 0 -} +type stubLinkAddressCache struct{} -func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { -} +func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {} type stubNUDHandler struct { probeCount int @@ -148,15 +145,15 @@ func (*testInterface) Promiscuous() bool { return false } -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) } -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var r stack.RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr @@ -643,7 +640,6 @@ func TestLinkResolution(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that // doesn't provoke NDP discovery. @@ -653,8 +649,12 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } - if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { - t.Fatalf("ep.Write(_): %s", err) + { + var r bytes.Reader + r.Reset(hdr.View()) + if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { + t.Fatalf("ep.Write(_): %s", err) + } } for _, args := range []routeArgs{ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, @@ -1283,7 +1283,7 @@ func TestLinkAddressRequest(t *testing.T) { localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectedErr *tcpip.Error + expectedErr tcpip.Error expectedRemoteAddr tcpip.Address expectedRemoteLinkAddr tcpip.LinkAddress }{ @@ -1321,79 +1321,80 @@ func TestLinkAddressRequest(t *testing.T) { name: "Unicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrBadLocalAddress{}, }, { name: "Multicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrBadLocalAddress{}, }, { name: "Unicast with no local address available", remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, }, { name: "Multicast with no local address available", remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, }, } for _, test := range tests { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - p := s.NetworkProtocolInstance(ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := s.CreateNIC(nicID, linkEP); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } } - } - // We pass a test network interface to LinkAddressRequest with the same NIC - // ID and link endpoint used by the NIC we created earlier so that we can - // mock a link address request and observe the packets sent to the link - // endpoint even though the stack uses the real NIC. - if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) - } + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) + } - if test.expectedErr != nil { - return - } + if test.expectedErr != nil { + return + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) - } - if pkt.Route.RemoteAddress != test.expectedRemoteAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) - } - if pkt.Route.LocalAddress != lladdr1 { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) - } - checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), - checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedRemoteAddr), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), - )) + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = test.expectedRemoteLinkAddr + if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) + } + checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), + checker.SrcAddr(lladdr1), + checker.DstAddr(test.expectedRemoteAddr), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), + )) + }) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index ae4a8f508..94043ed4e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -20,6 +20,7 @@ import ( "fmt" "hash/fnv" "math" + "reflect" "sort" "sync/atomic" "time" @@ -177,6 +178,7 @@ type endpoint struct { dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack + stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. @@ -305,17 +307,17 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. -func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } if addressEndpoint.GetKind() != stack.PermanentTentative { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an @@ -367,14 +369,14 @@ func (e *endpoint) transitionForwarding(forwarding bool) { } // Enable implements stack.NetworkEndpoint. -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If the NIC is not enabled, the endpoint can't do anything meaningful so // don't enable the endpoint. if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } // If the endpoint is already enabled, there is nothing for it to do. @@ -416,7 +418,7 @@ func (e *endpoint) Enable() *tcpip.Error { // // Addresses may have aleady completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. - var err *tcpip.Error + var err tcpip.Error e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { addr := addressEndpoint.AddressWithPrefix().Address if !header.IsV6UnicastAddress(addr) { @@ -497,7 +499,9 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } @@ -553,11 +557,11 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { +func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) tcpip.Error { extHdrsLen := extensionHeaders.Length() length := pkt.Size() + extensionHeaders.Length() if length > math.MaxUint16 { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ @@ -583,7 +587,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta // fragments left to be processed. The IP header must already be present in the // original packet. The transport header protocol number is required to avoid // parsing the IPv6 extension headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { networkHeader := header.IPv6(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are @@ -596,13 +600,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // of 8 as per RFC 8200 section 4.5: // Each complete fragment, except possibly the last ("rightmost") one, is // an integer multiple of 8 octets long. - return 0, 1, tcpip.ErrMessageTooLong + return 0, 1, &tcpip.ErrMessageTooLong{} } if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. - return 0, 1, tcpip.ErrMessageTooLong + return 0, 1, &tcpip.ErrMessageTooLong{} } pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) @@ -622,17 +626,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { return err } // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. - e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() + e.stats.ip.IPTablesOutputDropped.Increment() return nil } @@ -660,7 +664,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { @@ -675,36 +679,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) - r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + stats.PacketsSent.IncrementBy(uint64(sent)) + stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } - r.Stats().IP.PacketsSent.Increment() + stats.PacketsSent.Increment() return nil } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } @@ -712,28 +717,29 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + stats := e.stats.ip linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { return 0, err } networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } if packetMustBeFragmented(pb, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt return nil }); err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } // Remove the packet that was just fragmented and process the rest. @@ -743,19 +749,19 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } return n, err } - r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. @@ -779,8 +785,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) + stats.PacketsSent.IncrementBy(uint64(n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -788,17 +794,17 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe n++ } - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) // Dropped packets aren't errors, so include them in the return value. return n + len(dropped), nil } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required checks. h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } ip := header.IPv6(h) @@ -823,14 +829,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // sending the packet. proto, _, _, _, ok := parse.IPv6(pkt) if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv6(pkt.NetworkHeader().View()) hopLimit := h.HopLimit() if hopLimit <= 1 { @@ -852,7 +858,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { networkEndpoint.(*endpoint).handlePacket(pkt) return nil } - if err != tcpip.ErrBadAddress { + if _, ok := err.(*tcpip.ErrBadAddress); !ok { return err } @@ -882,19 +888,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - stats.IP.PacketsReceived.Increment() + stats := e.stats.ip + + stats.PacketsReceived.Increment() if !e.isEnabled() { - stats.IP.DisabledPacketsReceived.Increment() + stats.DisabledPacketsReceived.Increment() return } // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesPreroutingDropped.Increment() + stats.IPTablesPreroutingDropped.Increment() return } } @@ -906,11 +914,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() - stats := e.protocol.stack.Stats() + stats := e.stats.ip h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } srcAddr := h.SourceAddress() @@ -920,7 +928,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. if header.IsV6MulticastAddress(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.InvalidSourceAddressesReceived.Increment() return } @@ -930,7 +938,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() + stats.InvalidDestinationAddressesReceived.Increment() return } @@ -950,9 +958,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesInputDropped.Increment() + stats.IPTablesInputDropped.Increment() return } @@ -962,7 +971,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -986,7 +995,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -1075,8 +1084,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } if done { @@ -1103,8 +1112,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } } @@ -1112,8 +1121,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } @@ -1126,8 +1135,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // of the fragment, pointing to the Payload Length field of the // fragment packet. if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, @@ -1147,8 +1156,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // the fragment, pointing to the Fragment Offset field of the fragment // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, @@ -1173,8 +1182,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt, ) if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } @@ -1194,7 +1203,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -1244,12 +1253,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - stats.IP.PacketsDelivered.Increment() + stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p e.handleICMP(pkt, hasFragmentHeader) } else { - stats.IP.PacketsDelivered.Increment() + stats.PacketsDelivered.Increment() switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: @@ -1314,7 +1323,7 @@ func (e *endpoint) Close() { e.mu.addressableEndpointState.Cleanup() e.mu.Unlock() - e.protocol.forgetEndpoint(e) + e.protocol.forgetEndpoint(e.nic.ID()) } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -1323,7 +1332,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. e.mu.Lock() @@ -1338,7 +1347,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // solicited-node multicast group and start duplicate address detection. // // Precondition: e.mu must be write locked. -func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) if err != nil { return nil, err @@ -1367,13 +1376,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } // RemovePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } return e.removePermanentEndpointLocked(addressEndpoint, true) @@ -1383,7 +1392,7 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { // it works with a stack.AddressEndpoint. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error { +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() unicast := header.IsV6UnicastAddress(addr.Address) if unicast { @@ -1408,12 +1417,12 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) + err := e.leaveGroupLocked(snmc) // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { - return err + if _, ok := err.(*tcpip.ErrBadLocalAddress); ok { + err = nil } - - return nil + return err } // hasPermanentAddressLocked returns true if the endpoint has a permanent @@ -1623,7 +1632,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) @@ -1632,9 +1641,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { // joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } e.mu.mld.joinGroup(addr) @@ -1642,7 +1651,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) @@ -1651,7 +1660,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error { return e.mu.mld.leaveGroup(addr) } @@ -1662,6 +1671,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { return e.mu.mld.isInGroup(addr) } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1673,7 +1687,9 @@ type protocol struct { mu struct { sync.RWMutex - eps map[*endpoint]struct{} + // eps is keyed by NICID to allow protocol methods to retrieve an endpoint + // when handling a packet, by looking at which NIC handled the packet. + eps map[tcpip.NICID]*endpoint } ids []uint32 @@ -1730,37 +1746,42 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e.mu.mld.init(e) e.mu.Unlock() + stackStats := p.stack.Stats() + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP) + e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V6) + p.mu.Lock() defer p.mu.Unlock() - p.mu.eps[e] = struct{}{} + p.mu.eps[nic.ID()] = e return e } -func (p *protocol) forgetEndpoint(e *endpoint) { +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() - delete(p.mu.eps, e) + delete(p.mu.eps, nicID) } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: p.SetDefaultTTL(uint8(*v)) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -1814,7 +1835,7 @@ func (p *protocol) SetForwarding(v bool) { return } - for ep := range p.mu.eps { + for _, ep := range p.mu.eps { ep.transitionForwarding(v) } } @@ -1823,9 +1844,9 @@ func (p *protocol) SetForwarding(v bool) { // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, // which includes the length of the extension headers. -func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error) { if linkMTU < header.IPv6MinimumMTU { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } // As per RFC 7112 section 5, we should discard packets if their IPv6 header @@ -1836,7 +1857,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Erro // bytes ensures that the header chain length does not exceed the IPv6 // minimum MTU. if networkHeadersLen > header.IPv6MinimumMTU { - return 0, tcpip.ErrMalformedHeader + return 0, &tcpip.ErrMalformedHeader{} } networkMTU := linkMTU - uint32(networkHeadersLen) @@ -1906,7 +1927,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { hashIV: hashIV, } p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) - p.mu.eps = make(map[*endpoint]struct{}) + p.mu.eps = make(map[tcpip.NICID]*endpoint) p.SetDefaultTTL(DefaultTTL) return p } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index b65c9d060..8248052a3 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -21,6 +21,7 @@ import ( "io/ioutil" "math" "net" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -370,12 +371,10 @@ func TestAddIpv6Address(t *testing.T) { t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) } - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok { + t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber) + } else if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr) } }) } @@ -997,8 +996,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { } // Should not have any more UDP packets. - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -1989,8 +1989,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { } } - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -2473,11 +2474,11 @@ func TestWriteStats(t *testing.T) { writers := []struct { name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) }{ { name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { @@ -2489,7 +2490,7 @@ func TestWriteStats(t *testing.T) { }, }, { name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) }, }, @@ -2499,7 +2500,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -2574,7 +2575,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } @@ -2582,30 +2583,44 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, return false, false } +func knownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + func TestClearEndpointFromProtocolOnClose(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) - { - proto.mu.Lock() - _, hasEP := proto.mu.eps[ep] - proto.mu.Unlock() - if !hasEP { - t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep) - } + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) } ep.Close() - { - proto.mu.Lock() - _, hasEP := proto.mu.eps[ep] - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep) - } + proto.mu.Lock() + _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEndpointAfterClose { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) } } @@ -2819,8 +2834,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize int allowPackets int outgoingErrors int - mockError *tcpip.Error - wantError *tcpip.Error + mockError tcpip.Error + wantError tcpip.Error }{ { description: "No frag", @@ -2829,8 +2844,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 0, outgoingErrors: 1, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag", @@ -2839,8 +2854,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 0, outgoingErrors: 3, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on second frag", @@ -2849,8 +2864,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 1, outgoingErrors: 2, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error when MTU is smaller than transport header", @@ -2860,7 +2875,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrMessageTooLong, + wantError: &tcpip.ErrMessageTooLong{}, }, { description: "Error when MTU is smaller than IPv6 minimum MTU", @@ -2870,7 +2885,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrInvalidEndpointState, + wantError: &tcpip.ErrInvalidEndpointState{}, }, } @@ -2884,8 +2899,8 @@ func TestFragmentationErrors(t *testing.T) { TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.wantError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + if diff := cmp.Diff(ft.wantError, err); diff != "" { + t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff) } if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) @@ -3053,3 +3068,22 @@ func TestForwarding(t *testing.T) { }) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // supposed to be bound. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index ec54d88cc..2cc0dfebd 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -68,14 +68,14 @@ func (mld *mldState) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: mld.ep.mu must be read locked. -func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) } // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: mld.ep.mu must be read locked. -func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } @@ -112,7 +112,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { // joinGroup handles joining a new group and sending and scheduling the required // messages. // -// If the group is already joined, returns tcpip.ErrDuplicateAddress. +// If the group is already joined, returns *tcpip.ErrDuplicateAddress. // // Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { @@ -131,13 +131,13 @@ func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { // required. // // Precondition: mld.ep.mu must be locked. -func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) tcpip.Error { // LeaveGroup returns false only if the group was not joined. if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } // softLeaveAll leaves all groups from the perspective of MLD, but remains @@ -166,14 +166,14 @@ func (mld *mldState) sendQueuedReports() { // writePacket assembles and sends an MLD packet. // // Precondition: mld.ep.mu must be read locked. -func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { - sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - var mldStat *tcpip.StatCounter +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, tcpip.Error) { + sentStats := mld.ep.stats.icmp.packetsSent + var mldStat tcpip.MultiCounterStat switch mldType { case header.ICMPv6MulticastListenerReport: - mldStat = sentStats.MulticastListenerReport + mldStat = sentStats.multicastListenerReport case header.ICMPv6MulticastListenerDone: - mldStat = sentStats.MulticastListenerDone + mldStat = sentStats.multicastListenerDone default: panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) } @@ -249,14 +249,14 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp Data: buffer.View(icmp).ToVectorisedView(), }) - if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.MLDHopLimit, }, extensionHeaders); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sentStats.Dropped.Increment() + sentStats.dropped.Increment() return false, err } mldStat.Increment() diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 1d8fee50b..d7dde1767 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -241,7 +241,7 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered @@ -614,10 +614,10 @@ type slaacPrefixState struct { // tentative. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { - return tcpip.ErrAddressFamilyNotSupported + return &tcpip.ErrAddressFamilyNotSupported{} } if addressEndpoint.GetKind() != stack.PermanentTentative { @@ -666,7 +666,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE dadDone := remaining == 0 - var err *tcpip.Error + var err tcpip.Error if !dadDone { err = ndp.sendDADPacket(addr, addressEndpoint) } @@ -717,7 +717,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // addr. // // addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error { snmc := header.SolicitedNodeAddr(addr) icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) @@ -731,8 +731,8 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add Data: buffer.View(icmp).ToVectorisedView(), }) - sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + sent := ndp.ep.stats.icmp.packetsSent + if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { @@ -740,10 +740,11 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } - sent.NeighborSolicit.Increment() + sent.neighborSolicit.Increment() + return nil } @@ -1855,20 +1856,20 @@ func (ndp *ndpState) startSolicitingRouters() { Data: buffer.View(icmpData).ToVectorisedView(), }) - sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + sent := ndp.ep.stats.icmp.packetsSent + if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. remaining = 0 } else { - sent.RouterSolicit.Increment() + sent.routerSolicit.Increment() remaining-- } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index b1a5a5510..1d38b8b05 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -90,7 +90,7 @@ type testNDPDispatcher struct { addr tcpip.Address } -func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { @@ -162,10 +162,15 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a +type linkResolutionResult struct { + linkAddr tcpip.LinkAddress + ok bool +} + +// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. -func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { +func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -231,45 +236,40 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + wantInvalid := uint64(0) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantSucccess = false + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) } } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + if err != nil { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) } + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } } -// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests +// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests // that receiving a valid NDP NS message with the Source Link Layer Address // option results in a new entry in the link address cache for the sender of // the message. -func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { +func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { const nicID = 1 tests := []struct { @@ -382,7 +382,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi } } -func TestNeighorSolicitationResponse(t *testing.T) { +func TestNeighborSolicitationResponse(t *testing.T) { const nicID = 1 nicAddr := lladdr0 remoteAddr := lladdr1 @@ -640,18 +640,12 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatal("expected an NDP NS response") } - if p.Route.LocalAddress != nicAddr { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } respNSDst := header.SolicitedNodeAddr(test.nsSrc) - if p.Route.RemoteAddress != respNSDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) - } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) + if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -727,10 +721,10 @@ func TestNeighorSolicitationResponse(t *testing.T) { } } -// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a +// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a // valid NDP NA message with the Target Link Layer Address option results in a // new entry in the link address cache for the target of the message. -func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { +func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -803,45 +797,40 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + wantInvalid := uint64(0) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantSucccess = false + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) } } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + if err != nil { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) } + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } } -// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests +// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests // that receiving a valid NDP NA message with the Target Link Layer Address // option does not result in a new entry in the neighbor cache for the target // of the message. -func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { +func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { const nicID = 1 tests := []struct { @@ -1183,6 +1172,118 @@ func TestNDPValidation(t *testing.T) { } +// TestNeighborAdvertisementValidation tests that the NIC validates received +// Neighbor Advertisements. +// +// In particular, if the IP Destination Address is a multicast address, and the +// Solicited flag is not zero, the Neighbor Advertisement is invalid and should +// be discarded. +func TestNeighborAdvertisementValidation(t *testing.T) { + tests := []struct { + name string + ipDstAddr tcpip.Address + solicitedFlag bool + valid bool + }{ + { + name: "Multicast IP destination address with Solicited flag set", + ipDstAddr: header.IPv6AllNodesMulticastAddress, + solicitedFlag: true, + valid: false, + }, + { + name: "Multicast IP destination address with Solicited flag unset", + ipDstAddr: header.IPv6AllNodesMulticastAddress, + solicitedFlag: false, + valid: true, + }, + { + name: "Unicast IP destination address with Solicited flag set", + ipDstAddr: lladdr0, + solicitedFlag: true, + valid: true, + }, + { + name: "Unicast IP destination address with Solicited flag unset", + ipDstAddr: lladdr0, + solicitedFlag: false, + valid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.MessageBody()) + na.SetTargetAddress(lladdr1) + na.SetSolicitedFlag(test.solicitedFlag) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: test.ipDstAddr, + }) + + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + rxNA := stats.NeighborAdvert + + if got := rxNA.Value(); got != 0 { + t.Fatalf("got rxNA = %d, want = 0", got) + } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + + if got := rxNA.Value(); got != 1 { + t.Fatalf("got rxNA = %d, want = 1", got) + } + var wantInvalid uint64 = 1 + if test.valid { + wantInvalid = 0 + } + if got := invalid.Value(); got != wantInvalid { + t.Fatalf("got invalid = %d, want = %d", got, wantInvalid) + } + // As per RFC 4861 section 7.2.5: + // When a valid Neighbor Advertisement is received ... + // If no entry exists, the advertisement SHOULD be silently discarded. + // There is no need to create an entry if none exists, since the + // recipient has apparently not initiated any communication with the + // target. + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 0 { + t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + }) + } +} + // TestRouterAdvertValidation tests that when the NIC is configured to handle // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go new file mode 100644 index 000000000..0839be3cd --- /dev/null +++ b/pkg/tcpip/network/ipv6/stats.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.IPNetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to the IPv6 protocol family. +type Stats struct { + // IP holds IPv6 statistics. + IP tcpip.IPStats + + // ICMP holds ICMPv6 statistics. + ICMP tcpip.ICMPv6Stats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + +// IPStats implements stack.IPNetworkEndointStats +func (s *Stats) IPStats() *tcpip.IPStats { + return &s.IP +} + +type sharedStats struct { + localStats Stats + ip ip.MultiCounterIPStats + icmp multiCounterICMPv6Stats +} + +// LINT.IfChange(multiCounterICMPv6PacketStats) + +type multiCounterICMPv6PacketStats struct { + echoRequest tcpip.MultiCounterStat + echoReply tcpip.MultiCounterStat + dstUnreachable tcpip.MultiCounterStat + packetTooBig tcpip.MultiCounterStat + timeExceeded tcpip.MultiCounterStat + paramProblem tcpip.MultiCounterStat + routerSolicit tcpip.MultiCounterStat + routerAdvert tcpip.MultiCounterStat + neighborSolicit tcpip.MultiCounterStat + neighborAdvert tcpip.MultiCounterStat + redirectMsg tcpip.MultiCounterStat + multicastListenerQuery tcpip.MultiCounterStat + multicastListenerReport tcpip.MultiCounterStat + multicastListenerDone tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6PacketStats) init(a, b *tcpip.ICMPv6PacketStats) { + m.echoRequest.Init(a.EchoRequest, b.EchoRequest) + m.echoReply.Init(a.EchoReply, b.EchoReply) + m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) + m.packetTooBig.Init(a.PacketTooBig, b.PacketTooBig) + m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded) + m.paramProblem.Init(a.ParamProblem, b.ParamProblem) + m.routerSolicit.Init(a.RouterSolicit, b.RouterSolicit) + m.routerAdvert.Init(a.RouterAdvert, b.RouterAdvert) + m.neighborSolicit.Init(a.NeighborSolicit, b.NeighborSolicit) + m.neighborAdvert.Init(a.NeighborAdvert, b.NeighborAdvert) + m.redirectMsg.Init(a.RedirectMsg, b.RedirectMsg) + m.multicastListenerQuery.Init(a.MulticastListenerQuery, b.MulticastListenerQuery) + m.multicastListenerReport.Init(a.MulticastListenerReport, b.MulticastListenerReport) + m.multicastListenerDone.Init(a.MulticastListenerDone, b.MulticastListenerDone) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6PacketStats) + +// LINT.IfChange(multiCounterICMPv6SentPacketStats) + +type multiCounterICMPv6SentPacketStats struct { + multiCounterICMPv6PacketStats + dropped tcpip.MultiCounterStat + rateLimited tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6SentPacketStats) init(a, b *tcpip.ICMPv6SentPacketStats) { + m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats) + m.dropped.Init(a.Dropped, b.Dropped) + m.rateLimited.Init(a.RateLimited, b.RateLimited) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6SentPacketStats) + +// LINT.IfChange(multiCounterICMPv6ReceivedPacketStats) + +type multiCounterICMPv6ReceivedPacketStats struct { + multiCounterICMPv6PacketStats + unrecognized tcpip.MultiCounterStat + invalid tcpip.MultiCounterStat + routerOnlyPacketsDroppedByHost tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6ReceivedPacketStats) init(a, b *tcpip.ICMPv6ReceivedPacketStats) { + m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats) + m.unrecognized.Init(a.Unrecognized, b.Unrecognized) + m.invalid.Init(a.Invalid, b.Invalid) + m.routerOnlyPacketsDroppedByHost.Init(a.RouterOnlyPacketsDroppedByHost, b.RouterOnlyPacketsDroppedByHost) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6ReceivedPacketStats) + +// LINT.IfChange(multiCounterICMPv6Stats) + +type multiCounterICMPv6Stats struct { + packetsSent multiCounterICMPv6SentPacketStats + packetsReceived multiCounterICMPv6ReceivedPacketStats +} + +func (m *multiCounterICMPv6Stats) init(a, b *tcpip.ICMPv6Stats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6Stats) diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index d0ffc299a..bd62c4482 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -6,8 +6,10 @@ go_library( name = "testutil", srcs = [ "testutil.go", + "testutil_unsafe.go", ], visibility = [ + "//pkg/tcpip/network/arp:__pkg__", "//pkg/tcpip/network/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 3af44991f..f5fa77b65 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -19,6 +19,8 @@ package testutil import ( "fmt" "math/rand" + "reflect" + "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -33,7 +35,7 @@ type MockLinkEndpoint struct { WrittenPackets []*stack.PacketBuffer mtu uint32 - err *tcpip.Error + err tcpip.Error allowPackets int } @@ -41,7 +43,7 @@ type MockLinkEndpoint struct { // // err is the error that will be returned once allowPackets packets are written // to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { +func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { return &MockLinkEndpoint{ mtu: mtu, err: err, @@ -62,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } // WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if ep.allowPackets == 0 { return ep.err } @@ -72,7 +74,7 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip } // WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { var n int for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { @@ -127,3 +129,69 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi } return pkt } + +func checkFieldCounts(ref, multi reflect.Value) error { + refTypeName := ref.Type().Name() + multiTypeName := multi.Type().Name() + refNumField := ref.NumField() + multiNumField := multi.NumField() + + if refNumField != multiNumField { + return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) + } + + return nil +} + +func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { + s, ok := ref.Addr().Interface().(**tcpip.StatCounter) + if !ok { + return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) + } + + // The field names are expected to match (case insensitive). + if !strings.EqualFold(refName, multiName) { + return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) + } + + base := (*s).Value() + m.Increment() + if (*s).Value() != base+1 { + return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) + } + + return nil +} + +// ValidateMultiCounterStats verifies that every counter stored in multi is +// correctly tracking its counterpart in the given counters. +func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { + for _, c := range counters { + if err := checkFieldCounts(c, multi); err != nil { + return err + } + } + + for i := 0; i < multi.NumField(); i++ { + multiName := multi.Type().Field(i).Name + multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) + + if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { + for _, c := range counters { + if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { + return err + } + } + } else { + var countersNextField []reflect.Value + for _, c := range counters { + countersNextField = append(countersNextField, c.Field(i)) + } + if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/testutil/testutil_unsafe.go new file mode 100644 index 000000000..5ff764800 --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil_unsafe.go @@ -0,0 +1,26 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "reflect" + "unsafe" +) + +// unsafeExposeUnexportedFields takes a Value and returns a version of it in +// which even unexported fields can be read and written. +func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { + return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() +} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 2bad05a2e..57abec5c9 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -18,5 +18,6 @@ go_test( library = ":ports", deps = [ "//pkg/tcpip", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index d87193650..11dbdbbcf 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -329,7 +329,7 @@ func NewPortManager() *PortManager { // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { offset := uint32(rand.Int31n(numEphemeralPorts)) return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) } @@ -348,7 +348,7 @@ func (s *PortManager) incPortHint() { // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. -func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) if err == nil { s.incPortHint() @@ -361,7 +361,7 @@ func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uin // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { for i := uint32(0); i < count; i++ { port = uint16(FirstEphemeral + (offset+i)%count) ok, err := testPort(port) @@ -374,7 +374,7 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } } - return 0, tcpip.ErrNoPortAvailable + return 0, &tcpip.ErrNoPortAvailable{} } // IsPortAvailable tests if the given port is available on all given protocols. @@ -404,7 +404,7 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // An optional testPort closure can be passed in which if provided will be used // to test if the picked port can be used. The function should return true if // the port is safe to use, false otherwise. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() @@ -414,17 +414,17 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // protocols. if port != 0 { if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { - return 0, tcpip.ErrPortInUse + return 0, &tcpip.ErrPortInUse{} } if testPort != nil && !testPort(port) { s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) - return 0, tcpip.ErrPortInUse + return 0, &tcpip.ErrPortInUse{} } return port, nil } // A port wasn't specified, so try to find one. - return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { return false, nil } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 4bc949fd8..e70fbb72b 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -32,7 +33,7 @@ const ( type portReserveTestAction struct { port uint16 ip tcpip.Address - want *tcpip.Error + want tcpip.Error flags Flags release bool device tcpip.NICID @@ -50,19 +51,19 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress, want: nil}, {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, + {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, }, }, { tname: "bind to inaddr any", actions: []portReserveTestAction{ {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -80,8 +81,8 @@ func TestPortReservation(t *testing.T) { {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, @@ -91,14 +92,14 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, @@ -107,7 +108,7 @@ func TestPortReservation(t *testing.T) { tname: "bind twice with device fails", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 3, want: nil}, - {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 3, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind to device", @@ -119,50 +120,50 @@ func TestPortReservation(t *testing.T) { tname: "bind to device and then without device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind without device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "binding with reuseport and device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "mixing reuseport and not reuseport by binding to device", @@ -177,14 +178,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. @@ -195,7 +196,7 @@ func TestPortReservation(t *testing.T) { tname: "bind twice with reuseport once", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "release an unreserved device", @@ -213,16 +214,16 @@ func TestPortReservation(t *testing.T) { tname: "bind with reuseaddr", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, }, }, { tname: "bind twice with reuseaddr once", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport", @@ -236,14 +237,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport, and then reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", @@ -264,14 +265,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseport, and then reuseaddr and reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", @@ -283,7 +284,7 @@ func TestPortReservation(t *testing.T) { tname: "bind tuple with reuseaddr, and then wildcard", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", @@ -295,7 +296,7 @@ func TestPortReservation(t *testing.T) { tname: "bind tuple with reuseaddr, and then wildcard", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind two tuples with reuseaddr", @@ -313,7 +314,7 @@ func TestPortReservation(t *testing.T) { tname: "bind wildcard, and then tuple with reuseaddr", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind wildcard twice with reuseaddr", @@ -333,8 +334,8 @@ func TestPortReservation(t *testing.T) { continue } gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) - if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) + if diff := cmp.Diff(test.want, err); diff != "" { + t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) @@ -345,30 +346,29 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { - customErr := &tcpip.Error{} for _, test := range []struct { name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error + f func(port uint16) (bool, tcpip.Error) + wantErr tcpip.Error wantPort uint16 }{ { name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, { name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr + f: func(port uint16) (bool, tcpip.Error) { + return false, &tcpip.ErrBadBuffer{} }, - wantErr: customErr, + wantErr: &tcpip.ErrBadBuffer{}, }, { name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port == FirstEphemeral+42 { return true, nil } @@ -378,49 +378,52 @@ func TestPickEphemeralPort(t *testing.T) { }, { name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port < FirstEphemeral { return true, nil } return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() - if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + port, err := pm.PickEphemeralPort(test.f) + if diff := cmp.Diff(test.wantErr, err); diff != "" { + t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) + } + if port != test.wantPort { + t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) } }) } } func TestPickEphemeralPortStable(t *testing.T) { - customErr := &tcpip.Error{} for _, test := range []struct { name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error + f func(port uint16) (bool, tcpip.Error) + wantErr tcpip.Error wantPort uint16 }{ { name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, { name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr + f: func(port uint16) (bool, tcpip.Error) { + return false, &tcpip.ErrBadBuffer{} }, - wantErr: customErr, + wantErr: &tcpip.ErrBadBuffer{}, }, { name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port == FirstEphemeral+42 { return true, nil } @@ -430,20 +433,24 @@ func TestPickEphemeralPortStable(t *testing.T) { }, { name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port < FirstEphemeral { return true, nil } return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) - if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + port, err := pm.PickEphemeralPortStable(portOffset, test.f) + if diff := cmp.Diff(test.wantErr, err); diff != "" { + t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) + } + if port != test.wantPort { + t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) } }) } diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index cf0a5fefe..db9b91815 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -8,7 +8,6 @@ go_binary( visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/rawfile", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3b4f900e3..856ea998d 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -41,7 +41,7 @@ package main import ( - "bufio" + "bytes" "fmt" "log" "math/rand" @@ -51,7 +51,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" @@ -71,24 +70,21 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) { close(ch) }() - r := bufio.NewReader(os.Stdin) - for { - v := buffer.NewView(1024) - n, err := r.Read(v) - if err != nil { - return - } - - v.CapLength(n) - for len(v) > 0 { - n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - if err != nil { - fmt.Println("Write failed:", err) - return + var b bytes.Buffer + if err := func() error { + for { + if _, err := b.ReadFrom(os.Stdin); err != nil { + return fmt.Errorf("b.ReadFrom failed: %w", err) } - v.TrimFront(int(n)) + for b.Len() != 0 { + if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil { + return fmt.Errorf("ep.Write failed: %s", err) + } + } } + }(); err != nil { + fmt.Println(err) } } @@ -179,7 +175,7 @@ func main() { waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventOut) terr := ep.Connect(remote) - if terr == tcpip.ErrConnectStarted { + if _, ok := terr.(*tcpip.ErrConnectStarted); ok { fmt.Println("Connect is pending...") <-notifyCh terr = ep.LastError() @@ -202,11 +198,11 @@ func main() { for { _, err := ep.Read(os.Stdout, tcpip.ReadOptions{}) if err != nil { - if err == tcpip.ErrClosedForReceive { + if _, ok := err.(*tcpip.ErrClosedForReceive); ok { break } - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 3ac562756..9b23df3a9 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -20,6 +20,7 @@ package main import ( + "bytes" "flag" "io" "log" @@ -50,7 +51,7 @@ type endpointWriter struct { } type tcpipError struct { - inner *tcpip.Error + inner tcpip.Error } func (e *tcpipError) Error() string { @@ -58,7 +59,9 @@ func (e *tcpipError) Error() string { } func (e *endpointWriter) Write(p []byte) (int, error) { - n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(p) + n, err := e.ep.Write(&r, tcpip.WriteOptions{}) if err != nil { return int(n), &tcpipError{ inner: err, @@ -86,7 +89,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { for { _, err := ep.Read(&w, tcpip.ReadOptions{}) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } @@ -214,7 +217,7 @@ func main() { for { n, wq, err := ep.Accept(nil) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index f3ad40fdf..019d6a63c 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -15,11 +15,16 @@ package tcpip import ( + "math" "sync/atomic" "gvisor.dev/gvisor/pkg/sync" ) +// PacketOverheadFactor is used to multiply the value provided by the user on a +// SetSockOpt for setting the send/receive buffer sizes sockets. +const PacketOverheadFactor = 2 + // SocketOptionsHandler holds methods that help define endpoint specific // behavior for socket level socket options. These must be implemented by // endpoints to get notified when socket level options are set. @@ -41,13 +46,19 @@ type SocketOptionsHandler interface { OnCorkOptionSet(v bool) // LastError is invoked when SO_ERROR is read for an endpoint. - LastError() *Error + LastError() Error // UpdateLastError updates the endpoint specific last error field. - UpdateLastError(err *Error) + UpdateLastError(err Error) // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. HasNIC(v int32) bool + + // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE. + GetSendBufferSize() (int64, Error) + + // IsUnixSocket is invoked to check if the socket is of unix domain. + IsUnixSocket() bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -72,18 +83,39 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} // LastError implements SocketOptionsHandler.LastError. -func (*DefaultSocketOptionsHandler) LastError() *Error { +func (*DefaultSocketOptionsHandler) LastError() Error { return nil } // UpdateLastError implements SocketOptionsHandler.UpdateLastError. -func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} +func (*DefaultSocketOptionsHandler) UpdateLastError(Error) {} // HasNIC implements SocketOptionsHandler.HasNIC. func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { return false } +// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize. +func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) { + return 0, nil +} + +// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket. +func (*DefaultSocketOptionsHandler) IsUnixSocket() bool { + return false +} + +// StackHandler holds methods to access the stack options. These must be +// implemented by the stack. +type StackHandler interface { + // Option allows retrieving stack wide options. + Option(option interface{}) Error + + // TransportProtocolOption allows retrieving individual protocol level + // option values. + TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) Error +} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -91,6 +123,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { type SocketOptions struct { handler SocketOptionsHandler + // StackHandler is initialized at the creation time and will not change. + stackHandler StackHandler `state:"manual"` + // These fields are accessed and modified using atomic operations. // broadcastEnabled determines whether datagram sockets are allowed to @@ -170,6 +205,14 @@ type SocketOptions struct { // bindToDevice determines the device to which the socket is bound. bindToDevice int32 + // getSendBufferLimits provides the handler to get the min, default and + // max size for send buffer. It is initialized at the creation time and + // will not change. + getSendBufferLimits GetSendBufferLimits `state:"manual"` + + // sendBufferSize determines the send buffer size for this socket. + sendBufferSize int64 + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -180,8 +223,10 @@ type SocketOptions struct { // InitHandler initializes the handler. This must be called before using the // socket options utility. -func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) { so.handler = handler + so.stackHandler = stack + so.getSendBufferLimits = getSendBufferLimits } func storeAtomicBool(addr *uint32, v bool) { @@ -193,7 +238,7 @@ func storeAtomicBool(addr *uint32, v bool) { } // SetLastError sets the last error for a socket. -func (so *SocketOptions) SetLastError(err *Error) { +func (so *SocketOptions) SetLastError(err Error) { so.handler.UpdateLastError(err) } @@ -378,7 +423,7 @@ func (so *SocketOptions) SetRecvError(v bool) { } // GetLastError gets value for SO_ERROR option. -func (so *SocketOptions) GetLastError() *Error { +func (so *SocketOptions) GetLastError() Error { return so.handler.LastError() } @@ -435,7 +480,7 @@ type SockError struct { sockErrorEntry // Err is the error caused by the errant packet. - Err *Error + Err Error // ErrOrigin indicates the error origin. ErrOrigin SockErrOrigin // ErrType is the type in the ICMP header. @@ -493,7 +538,7 @@ func (so *SocketOptions) QueueErr(err *SockError) { } // QueueLocalErr queues a local error onto the local queue. -func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { +func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { so.QueueErr(&SockError{ Err: err, ErrOrigin: SockExtErrorOriginLocal, @@ -510,11 +555,52 @@ func (so *SocketOptions) GetBindToDevice() int32 { } // SetBindToDevice sets value for SO_BINDTODEVICE option. -func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { +func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { if !so.handler.HasNIC(bindToDevice) { - return ErrUnknownDevice + return &ErrUnknownDevice{} } atomic.StoreInt32(&so.bindToDevice, bindToDevice) return nil } + +// GetSendBufferSize gets value for SO_SNDBUF option. +func (so *SocketOptions) GetSendBufferSize() (int64, Error) { + if so.handler.IsUnixSocket() { + return so.handler.GetSendBufferSize() + } + return atomic.LoadInt64(&so.sendBufferSize), nil +} + +// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the +// stack handler should be invoked to set the send buffer size. +func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { + if so.handler.IsUnixSocket() { + return + } + + v := sendBufferSize + if notify { + // TODO(b/176170271): Notify waiters after size has grown. + // Make sure the send buffer size is within the min and max + // allowed. + ss := so.getSendBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + // Multiply it by factor of 2. + if v > max { + v = max + } + + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min + } + } else { + v = math.MaxInt32 + } + } + atomic.StoreInt64(&so.sendBufferSize, v) +} diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index cd423bf71..e5590ecc0 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,7 +117,7 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) @@ -143,10 +143,10 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr // AddAndAcquireTemporaryAddress adds a temporary address. // -// Returns tcpip.ErrDuplicateAddress if the address exists. +// Returns *tcpip.ErrDuplicateAddress if the address exists. // // The temporary address's endpoint is acquired and returned. -func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) @@ -176,11 +176,11 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // If the addressable endpoint already has the address in a non-permanent state, // and addAndAcquireAddressLocked is adding a permanent address, that address is // promoted in place and its properties set to the properties provided. If the -// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// address already exists in any other state, then *tcpip.ErrDuplicateAddress is // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -190,7 +190,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We are adding a non-permanent address but the address exists. No need // to go any further since we can only promote existing temporary/expired // addresses to permanent. - return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } addrState.mu.Lock() @@ -198,7 +198,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState.mu.Unlock() // We are adding a permanent address but a permanent address already // exists. - return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } if addrState.mu.refs == 0 { @@ -293,7 +293,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // RemovePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { a.mu.Lock() defer a.mu.Unlock() return a.removePermanentAddressLocked(addr) @@ -303,10 +303,10 @@ func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *t // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error { addrState, ok := a.mu.endpoints[addr] if !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } return a.removePermanentEndpointLocked(addrState) @@ -314,10 +314,10 @@ func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Addre // RemovePermanentEndpoint removes the passed endpoint if it is associated with // a and permanent. -func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) tcpip.Error { addrState, ok := ep.(*addressState) if !ok || addrState.addressableEndpointState != a { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } a.mu.Lock() @@ -329,9 +329,9 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) tcpip.Error { if !addrState.GetKind().IsPermanent() { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } addrState.SetKind(PermanentExpired) @@ -574,9 +574,11 @@ func (a *AddressableEndpointState) Cleanup() { defer a.mu.Unlock() for _, ep := range a.mu.endpoints { - // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is // not a permanent address. - if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := a.removePermanentEndpointLocked(ep); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) } } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5e649cca6..54617f2e6 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -198,15 +198,15 @@ type bucket struct { // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { +func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ @@ -617,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -631,10 +631,10 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ conn, _ := ct.connForTID(tid) if conn == nil { // Not a tracked connection. - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } else if conn.manip == manipNone { // Unmanipulated connection. - return "", 0, tcpip.ErrInvalidOptionValue + return "", 0, &tcpip.ErrInvalidOptionValue{} } return conn.original.dstAddr, conn.original.dstPort, nil diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 4908848e9..63a42a2ea 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -41,6 +41,8 @@ const ( protocolNumberOffset = 2 ) +var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) + // fwdTestNetworkEndpoint is a network-layer protocol endpoint. // Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only // use the first three: destination address, source address, and transport @@ -53,9 +55,7 @@ type fwdTestNetworkEndpoint struct { dispatcher TransportDispatcher } -var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) - -func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { +func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { return nil } @@ -104,7 +104,7 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { +func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { return 0 } @@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) @@ -124,14 +124,14 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error { // The network header should not already be populated. if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) @@ -141,6 +141,21 @@ func (f *fwdTestNetworkEndpoint) Close() { f.AddressableEndpointState.Cleanup() } +// Stats implements stack.NetworkEndpoint. +func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats { + return &fwdTestNetworkEndpointStats{} +} + +var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil) + +type fwdTestNetworkEndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} + +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) +var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) + // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { @@ -158,18 +173,15 @@ type fwdTestNetworkProtocol struct { } } -var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) - -func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } -func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { +func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { +func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { return fwdTestNetDefaultPrefixLen } @@ -195,19 +207,19 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddress return e } -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) @@ -307,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -323,7 +335,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.WritePacket(r, gso, protocol, pkt) @@ -356,10 +368,6 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC UseNeighborCache: useNeighborCache, }) - if !useNeighborCache { - proto.addrCache = s.linkAddrCache - } - // Enable forwarding. s.SetForwarding(proto.Number(), true) @@ -389,13 +397,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC t.Fatal("AddAddress #2 failed:", err) } + nic, ok := s.nics[2] + if !ok { + t.Fatal("NIC 2 does not exist") + } if useNeighborCache { // Control the neighbor cache for NIC 2. - nic, ok := s.nics[2] - if !ok { - t.Fatal("failed to get the neighbor cache for NIC 2") - } proto.neigh = nic.neigh + } else { + proto.addrCache = nic.linkAddrCache } // Route all packets to NIC 2. @@ -481,7 +491,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -607,7 +617,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") } }, }, @@ -692,7 +702,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -768,7 +778,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -858,7 +868,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 09c7811fa..63832c200 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -229,7 +229,7 @@ func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { // ReplaceTable replaces or inserts table by name. It panics when an invalid id // is provided. -func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -267,11 +267,11 @@ const ( // dropped. // // TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from -// which address and nicName can be gathered. Currently, address is only -// needed for prerouting and nicName is only needed for output. +// which address can be gathered. Currently, address is only needed for +// prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { + if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(pkt, hook, nicName) { + if !rule.Filter.match(pkt, hook, inNicName, outNicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName) if hotdrop { return RuleDrop, 0 } @@ -483,11 +483,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 56a3e7861..fd9d61e39 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -210,8 +210,19 @@ type IPHeaderFilter struct { // filter will match packets that fail the source comparison. SrcInvert bool - // OutputInterface matches the name of the outgoing interface for the - // packet. + // InputInterface matches the name of the incoming interface for the packet. + InputInterface string + + // InputInterfaceMask masks the characters of the interface name when + // comparing with InputInterface. + InputInterfaceMask string + + // InputInterfaceInvert inverts the meaning of incoming interface check, + // i.e. when true the filter will match packets that fail the incoming + // interface comparison. + InputInterfaceInvert bool + + // OutputInterface matches the name of the outgoing interface for the packet. OutputInterface string // OutputInterfaceMask masks the characters of the interface name when @@ -228,7 +239,7 @@ type IPHeaderFilter struct { // // Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 // or IPv6 header length. -func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( // TODO(gvisor.dev/issue/170): Support other filter fields. @@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return false } - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(fl.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which - // begins with the name should be matched. - ifName := fl.OutputInterface - matches := nicName == ifName - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } - return fl.OutputInterfaceInvert != matches + switch hook { + case Prerouting, Input: + return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) + case Output: + return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) + case Forward, Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + return true + default: + panic(fmt.Sprintf("unknown hook: %d", hook)) } +} - return true +func matchIfName(nicName string, ifName string, invert bool) bool { + n := len(ifName) + if n == 0 { + // If the interface name is omitted in the filter, any interface will match. + return true + } + // If the interface name ends with '+', any interface which begins with the + // name should be matched. + var matches bool + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return matches != invert } // NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header @@ -320,7 +340,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index b600a1cab..930b8f795 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,12 +24,16 @@ import ( const linkAddrCacheSize = 512 // max cache entries +var _ LinkAddressCache = (*linkAddrCache)(nil) + // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. // // This struct is safe for concurrent use. type linkAddrCache struct { + nic *NIC + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -41,9 +45,9 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - cache struct { + mu struct { sync.Mutex - table map[tcpip.FullAddress]*linkAddrEntry + table map[tcpip.Address]*linkAddrEntry lru linkAddrEntryList } } @@ -77,31 +81,42 @@ type linkAddrEntry struct { // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry - // TODO(gvisor.dev/issue/5150): move these fields under mu. - // mu protects the fields below. - mu sync.RWMutex + cache *linkAddrCache - addr tcpip.FullAddress - linkAddr tcpip.LinkAddress - expiration time.Time - s entryState + mu struct { + sync.RWMutex - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} + addr tcpip.Address + linkAddr tcpip.LinkAddress + expiration time.Time + s entryState - // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. + done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(LinkResolutionResult) + } } func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { - for _, callback := range e.onResolve { - callback(linkAddr, len(linkAddr) != 0) + res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} + for _, callback := range e.mu.onResolve { + callback(res) } - e.onResolve = nil - if ch := e.done; ch != nil { + e.mu.onResolve = nil + if ch := e.mu.done; ch != nil { close(ch) - e.done = nil + e.mu.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) } } @@ -114,30 +129,30 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { // // Precondition: e.mu must be locked func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { - if e.s == incomplete && ns == ready { - e.notifyCompletionLocked(e.linkAddr) + if e.mu.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.mu.linkAddr) } - if expiration.IsZero() || expiration.After(e.expiration) { - e.expiration = expiration + if expiration.IsZero() || expiration.After(e.mu.expiration) { + e.mu.expiration = expiration } - e.s = ns + e.mu.s = ns } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { +func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. expiration := time.Now().Add(c.ageLimit) - c.cache.Lock() + c.mu.Lock() entry := c.getOrCreateEntryLocked(k) - c.cache.Unlock() - entry.mu.Lock() defer entry.mu.Unlock() - entry.linkAddr = v + c.mu.Unlock() + + entry.mu.linkAddr = v entry.changeStateLocked(ready, expiration) } @@ -150,19 +165,19 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { - if entry, ok := c.cache.table[k]; ok { - c.cache.lru.Remove(entry) - c.cache.lru.PushFront(entry) +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { + if entry, ok := c.mu.table[k]; ok { + c.mu.lru.Remove(entry) + c.mu.lru.PushFront(entry) return entry } var entry *linkAddrEntry - if len(c.cache.table) == linkAddrCacheSize { - entry = c.cache.lru.Back() + if len(c.mu.table) == linkAddrCacheSize { + entry = c.mu.lru.Back() entry.mu.Lock() - delete(c.cache.table, entry.addr) - c.cache.lru.Remove(entry) + delete(c.mu.table, entry.mu.addr) + c.mu.lru.Remove(entry) // Wake waiters and mark the soon-to-be-reused entry as expired. entry.notifyCompletionLocked("" /* linkAddr */) @@ -172,53 +187,56 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } *entry = linkAddrEntry{ - addr: k, - s: incomplete, + cache: c, } - c.cache.table[k] = entry - c.cache.lru.PushFront(entry) + entry.mu.Lock() + entry.mu.addr = k + entry.mu.s = incomplete + entry.mu.Unlock() + c.mu.table[k] = entry + c.mu.lru.PushFront(entry) return entry } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { - c.cache.Lock() - defer c.cache.Unlock() +func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + c.mu.Lock() entry := c.getOrCreateEntryLocked(k) entry.mu.Lock() defer entry.mu.Unlock() + c.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: - if !time.Now().After(entry.expiration) { + if !time.Now().After(entry.mu.expiration) { // Not expired. if onResolve != nil { - onResolve(entry.linkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true}) } - return entry.linkAddr, nil, nil + return entry.mu.linkAddr, nil, nil } entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: if onResolve != nil { - entry.onResolve = append(entry.onResolve, onResolve) + entry.mu.onResolve = append(entry.mu.onResolve, onResolve) } - if entry.done == nil { - entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + if entry.mu.done == nil { + entry.mu.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) + linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): @@ -234,10 +252,10 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has // succeeded and mark the entry accordingly. Returns true if request can stop, // false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { - c.cache.Lock() - defer c.cache.Unlock() - entry, ok := c.cache.table[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { + c.mu.Lock() + defer c.mu.Unlock() + entry, ok := c.mu.table[k] if !ok { // Entry was evicted from the cache. return true @@ -245,7 +263,7 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att entry.mu.Lock() defer entry.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: // Entry was made ready by resolver. case incomplete: @@ -255,19 +273,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att } // Max number of retries reached, delete entry. entry.notifyCompletionLocked("" /* linkAddr */) - delete(c.cache.table, k) + delete(c.mu.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } return true } -func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { +func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { c := &linkAddrCache{ + nic: nic, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } - c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d7ac6cf5f..466a5e8d9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -26,7 +26,7 @@ import ( ) type testaddr struct { - addr tcpip.FullAddress + addr tcpip.Address linkAddr tcpip.LinkAddress } @@ -35,7 +35,7 @@ var testAddrs = func() []testaddr { for i := 0; i < 4*linkAddrCacheSize; i++ { addr := fmt.Sprintf("Addr%06d", i) addrs = append(addrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + addr: tcpip.Address(addr), linkAddr: tcpip.LinkAddress("Link" + addr), }) } @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { @@ -59,8 +59,8 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { - if ta.addr.Addr == addr { - r.cache.add(ta.addr, ta.linkAddr) + if ta.addr == addr { + r.cache.AddLinkAddress(ta.addr, ta.linkAddr) break } } @@ -77,13 +77,13 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { got, ch, err := c.get(addr, linkRes, "", nil, nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { - return got, tcpip.ErrTimeout + return got, &tcpip.ErrTimeout{} } attemptedResolution = true <-ch @@ -93,17 +93,23 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } +func newEmptyNIC() *NIC { + n := &NIC{} + n.linkResQueue.init(n) + return n +} + func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. @@ -111,25 +117,25 @@ func TestCacheOverflow(t *testing.T) { e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. - c.cache.Lock() - defer c.cache.Unlock() + c.mu.Lock() + defer c.mu.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup @@ -137,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) { wg.Add(1) go func() { for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) } wg.Done() }() @@ -150,52 +156,53 @@ func TestCacheConcurrent(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] - c.cache.Lock() - defer c.cache.Unlock() - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + c.mu.Lock() + defer c.mu.Unlock() + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} e := testAddrs[0] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) + _, _, err := c.get(e.addr, linkRes, "", nil, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) e := testAddrs[0] l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.add(e.addr, l2) + c.AddLinkAddress(e.addr, l2) got, _, err = c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != l2 { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) } } @@ -206,15 +213,15 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) + t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } @@ -223,16 +230,16 @@ func TestCacheResolution(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} var requestCount uint32 @@ -244,17 +251,18 @@ func TestCacheResolutionFailed(t *testing.T) { e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) - e.addr.Addr += "2" - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + e.addr += "2" + a, err := getBlocking(c, e.addr, linkRes) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -265,11 +273,12 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) + c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} e := testAddrs[0] - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + a, err := getBlocking(c, e.addr, linkRes) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 61636cae5..64383bc7c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -45,6 +45,8 @@ const ( linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + defaultPrefixLen = 128 + // Extra time to use when waiting for an async event to occur. defaultAsyncPositiveEventTimeout = 10 * time.Second @@ -102,7 +104,7 @@ type ndpDADEvent struct { nicID tcpip.NICID addr tcpip.Address resolved bool - err *tcpip.Error + err tcpip.Error } type ndpRouterEvent struct { @@ -172,7 +174,7 @@ type ndpDispatcher struct { } // Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ nicID, @@ -309,7 +311,7 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) } @@ -330,8 +332,12 @@ func TestDADDisabled(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Should get the address immediately since we should not have performed @@ -344,12 +350,8 @@ func TestDADDisabled(t *testing.T) { default: t.Fatal("expected DAD event") } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatal(err) } // We should not have sent any NDP NS messages. @@ -440,31 +442,31 @@ func TestDADResolve(t *testing.T) { NIC: nicID, }}) - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Make sure the address does not resolve before the resolution time has // passed. time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) } // Should not get a route even if we specify the local address as the // tentative address. { r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -472,8 +474,8 @@ func TestDADResolve(t *testing.T) { } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -493,10 +495,8 @@ func TestDADResolve(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if addr.Address != addr1 { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Error(err) } // Should get a route using the address now that it is resolved. { @@ -662,12 +662,8 @@ func TestDADFail(t *testing.T) { // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Receive a packet to simulate an address conflict. @@ -691,12 +687,8 @@ func TestDADFail(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Attempting to add the address again should not fail if the address's @@ -777,12 +769,8 @@ func TestDADStop(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } test.stopFn(t, s) @@ -800,12 +788,8 @@ func TestDADStop(t *testing.T) { } if !test.skipFinalAddrCheck { - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } } @@ -901,26 +885,25 @@ func TestSetNDPConfigurations(t *testing.T) { } // Add addresses for each NIC. - if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) } - if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) + addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) } expectDADEvent(nicID2, addr2) - if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) + addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) } expectDADEvent(nicID3, addr3) // Address should not be considered bound to NIC(1) yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Should get the address on NIC(2) and NIC(3) @@ -928,31 +911,19 @@ func TestSetNDPConfigurations(t *testing.T) { // it as the stack was configured to not do DAD by // default and we only updated the NDP configurations on // NIC(1). - addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) + if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatal(err) } - addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) + if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatal(err) } // Sleep until right (500ms before) before resolution to // make sure the address didn't resolve on NIC(1) yet. const delta = 500 * time.Millisecond time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -970,12 +941,8 @@ func TestSetNDPConfigurations(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatal(err) } }) } @@ -2808,6 +2775,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2827,10 +2795,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN Gateway: llAddr3, NIC: nicID, }}) + if useNeighborCache { - s.AddStaticNeighbor(nicID, llAddr3, linkAddr3) + if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } else { - s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } return ndpDisp, e, s } @@ -2940,10 +2913,8 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3088,10 +3059,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3238,10 +3207,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("should not have %s in the list of addresses", addr2) } // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) @@ -3255,8 +3222,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { defer ep.Close() ep.SocketOptions().SetV6Only(true) - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) + { + err := ep.Connect(dstAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) + } } }) } @@ -3615,10 +3585,8 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index acee72572..88a3ff776 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() @@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // a node continues sending packets to that neighbor using the cached // link-layer address." if onResolve != nil { - onResolve(entry.neigh.LinkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true}) } return entry.neigh, nil, nil case Unknown, Incomplete, Failed: @@ -154,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } entry.handlePacketQueuedLocked(localAddr) - return entry.neigh, entry.done, tcpip.ErrWouldBlock + return entry.neigh, entry.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } @@ -297,10 +297,9 @@ func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Li // no matching entry for the remote address. } -// HandleUpperLevelConfirmation implements -// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in -// RFC 4861 section 7.3.1. -func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { +// handleUpperLevelConfirmation processes a confirmation of reachablity from +// some protocol that operates at a layer above the IP/link layer. +func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b96a56612..2870e4f66 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -194,7 +194,7 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { if !r.dropReplies { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { @@ -251,8 +251,8 @@ func TestNeighborCacheGetConfig(t *testing.T) { // No events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -273,8 +273,8 @@ func TestNeighborCacheSetConfig(t *testing.T) { // No events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -295,8 +295,9 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -321,11 +322,11 @@ func TestNeighborCacheEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { @@ -335,8 +336,8 @@ func TestNeighborCacheEntry(t *testing.T) { // No more events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -359,8 +360,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -385,11 +387,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } neigh.removeEntry(entry.Addr) @@ -407,15 +409,18 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + { + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + } } } @@ -462,8 +467,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -506,11 +512,11 @@ func (c *testContext) overflowCache(opts overflowOptions) error { }) c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -531,15 +537,15 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } return nil @@ -579,8 +585,9 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -603,11 +610,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the entry @@ -626,11 +633,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -668,11 +675,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added @@ -681,8 +688,8 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -712,11 +719,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a duplicate entry with a different link address @@ -736,8 +743,8 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) } c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -774,11 +781,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added @@ -796,11 +803,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -830,8 +837,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -854,11 +862,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Override the entry with a static one using the same address @@ -886,11 +894,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -932,8 +940,8 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { LinkAddr: entry.LinkAddr, State: Static, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -948,11 +956,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } opts := overflowOptions{ @@ -989,8 +997,9 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1014,11 +1023,11 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a static entry. @@ -1037,11 +1046,11 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1072,8 +1081,8 @@ func TestNeighborCacheClear(t *testing.T) { } nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEvents, nudDisp.events, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1094,8 +1103,9 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1118,11 +1128,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Clear the cache. @@ -1140,11 +1150,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1188,16 +1198,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1225,11 +1232,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1247,16 +1254,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1299,11 +1303,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1331,15 +1335,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1366,8 +1370,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1398,8 +1404,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1423,16 +1429,13 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1455,8 +1458,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } @@ -1489,8 +1492,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: updatedLinkAddr, State: Delay, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1507,8 +1510,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: updatedLinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1539,16 +1542,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1567,8 +1567,8 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } // Verify address resolution fails for an unknown address. @@ -1576,19 +1576,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1627,19 +1621,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1674,19 +1662,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1713,13 +1695,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1772,16 +1754,13 @@ func BenchmarkCacheClear(b *testing.B) { b.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - b.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 75afb3001..53ac9bb6e 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -96,7 +96,7 @@ type neighborEntry struct { done chan struct{} // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + onResolve []func(LinkResolutionResult) isRouter bool job *tcpip.Job @@ -143,13 +143,22 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd // // Precondition: e.mu MUST be locked. func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded} for _, callback := range e.onResolve { - callback(e.neigh.LinkAddr, succeeded) + callback(res) } e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index ec34ffa5a..140b8ca00 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { p := entryTestProbeInfo{ RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, @@ -266,16 +266,16 @@ func TestEntryInitiallyUnknown(t *testing.T) { // No probes should have been sent. linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -299,16 +299,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { // No probes should have been sent. linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -333,10 +333,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -352,10 +352,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { } { nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -374,10 +374,10 @@ func TestEntryUnknownToStale(t *testing.T) { // No probes should have been sent. runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -392,8 +392,8 @@ func TestEntryUnknownToStale(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -427,11 +427,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -453,10 +453,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -483,8 +483,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -515,10 +515,10 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -553,8 +553,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -579,10 +579,10 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -620,8 +620,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -646,10 +646,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -684,8 +684,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -710,10 +710,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -744,8 +744,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -785,10 +785,10 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -812,8 +812,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -850,10 +850,10 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -903,8 +903,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -932,10 +932,10 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -977,8 +977,8 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1005,10 +1005,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1054,8 +1054,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -1083,10 +1083,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1134,8 +1134,8 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1157,10 +1157,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1212,8 +1212,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1235,10 +1235,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1290,8 +1290,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1313,10 +1313,10 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1358,8 +1358,8 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1381,10 +1381,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1439,8 +1439,8 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1462,10 +1462,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1520,8 +1520,8 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1543,10 +1543,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1601,8 +1601,8 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1624,10 +1624,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1678,8 +1678,8 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1701,10 +1701,10 @@ func TestEntryStaleToDelay(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1752,8 +1752,8 @@ func TestEntryStaleToDelay(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1780,10 +1780,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1851,8 +1851,8 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1880,10 +1880,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1958,8 +1958,8 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1987,10 +1987,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2065,8 +2065,8 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2088,10 +2088,10 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2147,8 +2147,8 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2170,10 +2170,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2231,8 +2231,8 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2254,10 +2254,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2319,8 +2319,8 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2343,11 +2343,11 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2372,10 +2372,10 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2418,8 +2418,8 @@ func TestEntryDelayToProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -2448,11 +2448,11 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2474,10 +2474,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2539,8 +2539,8 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2563,11 +2563,11 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2589,10 +2589,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2658,8 +2658,8 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2682,11 +2682,11 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2709,10 +2709,10 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2772,8 +2772,8 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2806,10 +2806,10 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2878,8 +2878,8 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2907,11 +2907,11 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2933,10 +2933,10 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3015,8 +3015,8 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3044,11 +3044,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3070,10 +3070,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3149,8 +3149,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3178,11 +3178,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3204,10 +3204,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3283,8 +3283,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3309,11 +3309,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3336,11 +3336,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + t.Fatalf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) } e.mu.Lock() @@ -3406,8 +3406,8 @@ func TestEntryProbeToFailed(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3449,10 +3449,10 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -3498,8 +3498,8 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0f545f255..e56a624fe 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -53,6 +53,8 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution + linkAddrCache *linkAddrCache + mu struct { sync.RWMutex spoofing bool @@ -138,7 +140,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.linkResQueue.init() + nic.linkResQueue.init(nic) + nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. @@ -167,7 +170,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) } nic.LinkEndpoint.Attach(nic) @@ -228,7 +231,9 @@ func (n *NIC) disableLocked() { // // This matches linux's behaviour at the time of writing: // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported { + switch err := n.clearNeighbors(); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) } @@ -243,7 +248,7 @@ func (n *NIC) disableLocked() { // address (ff02::1), start DAD for permanent addresses, and start soliciting // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. -func (n *NIC) enable() *tcpip.Error { +func (n *NIC) enable() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -263,7 +268,7 @@ func (n *NIC) enable() *tcpip.Error { // remove detaches NIC from the link endpoint and releases network endpoint // resources. This guarantees no packets between this NIC and the network // stack. -func (n *NIC) remove() *tcpip.Error { +func (n *NIC) remove() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -299,48 +304,69 @@ func (n *NIC) IsLoopback() bool { } // WritePacket implements NetworkLinkEndpoint. -func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - // As per relevant RFCs, we should queue packets while we wait for link - // resolution to complete. - // - // RFC 1122 section 2.3.2.2 (for IPv4): - // The link layer SHOULD save (rather than discard) at least - // one (the latest) packet of each set of packets destined to - // the same unresolved IP address, and transmit the saved - // packet when the address has been resolved. - // - // RFC 4861 section 7.2.2 (for IPv6): - // While waiting for address resolution to complete, the sender MUST, for - // each neighbor, retain a small queue of packets waiting for address - // resolution to complete. The queue MUST hold at least one packet, and MAY - // contain more. However, the number of queued packets per neighbor SHOULD - // be limited to some small value. When a queue overflows, the new arrival - // SHOULD replace the oldest entry. Once address resolution completes, the - // node transmits any queued packets. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r.Acquire() - n.linkResQueue.enqueue(ch, r, protocol, pkt) - return nil +func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { + _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) + return err +} + +func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + switch pkt := pkt.(type) { + case *PacketBuffer: + if err := n.writePacket(r, gso, protocol, pkt); err != nil { + return 0, err } - return err + return 1, nil + case *PacketBufferList: + return n.writePackets(r, gso, protocol, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } +} - return n.writePacket(r.Fields(), gso, protocol, pkt) +func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + routeInfo, _, err := r.resolvedFields(nil) + switch err.(type) { + case nil: + return n.writePacketBuffer(routeInfo, gso, protocol, pkt) + case *tcpip.ErrWouldBlock: + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and + // MAY contain more. However, the number of queued packets per neighbor + // SHOULD be limited to some small value. When a queue overflows, the new + // arrival SHOULD replace the oldest entry. Once address resolution + // completes, the node transmits any queued packets. + return n.linkResQueue.enqueue(r, gso, protocol, pkt) + default: + return 0, err + } } // WritePacketToRemote implements NetworkInterface. -func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr return n.writePacket(r, gso, protocol, pkt) } -func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { return err } @@ -351,10 +377,18 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution - // is being peformed like WritePacket. - writtenPackets, err := n.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) +func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return n.enqueuePacketBuffer(r, gso, protocol, &pkts) +} + +func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol + } + + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { @@ -463,15 +497,15 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) @@ -535,63 +569,70 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } // removeAddress removes an address from n. -func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { for _, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { continue } - if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + switch err := addressableEndpoint.RemovePermanentAddress(addr); err.(type) { + case *tcpip.ErrBadLocalAddress: continue - } else { + default: return err } } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} +} + +func (n *NIC) confirmReachable(addr tcpip.Address) { + if n := n.neigh; n != nil { + n.handleUpperLevelConfirmation(addr) + } } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { if n.neigh != nil { entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) return entry.LinkAddr, ch, err } - return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve) + return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) } -func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { +func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { if n.neigh == nil { - return nil, tcpip.ErrNotSupported + return nil, &tcpip.ErrNotSupported{} } return n.neigh.entries(), nil } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { +func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } n.neigh.addStaticEntry(addr, linkAddress) return nil } -func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } if !n.neigh.removeEntry(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } return nil } -func (n *NIC) clearNeighbors() *tcpip.Error { +func (n *NIC) clearNeighbors() tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } n.neigh.clear() @@ -600,7 +641,7 @@ func (n *NIC) clearNeighbors() *tcpip.Error { // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. -func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as @@ -608,12 +649,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.JoinGroup(addr) @@ -621,15 +662,15 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.LeaveGroup(addr) @@ -879,9 +920,9 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { +func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { if n.neigh == nil { - return NUDConfigurations{}, tcpip.ErrNotSupported + return NUDConfigurations{}, &tcpip.ErrNotSupported{} } return n.neigh.config(), nil } @@ -890,22 +931,22 @@ func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { +func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } c.resetInvalidFields() n.neigh.setConfig(c) return nil } -func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { n.mu.Lock() defer n.mu.Unlock() eps, ok := n.mu.packetEPs[netProto] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.add(ep) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5b5c58afb..2f719fbe5 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -39,7 +39,7 @@ type testIPv6Endpoint struct { invalidatedRtr tcpip.Address } -func (*testIPv6Endpoint) Enable() *tcpip.Error { +func (*testIPv6Endpoint) Enable() tcpip.Error { return nil } @@ -65,21 +65,21 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { } // WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { return nil } // WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported + return 0, &tcpip.ErrNotSupported{} } // WriteHeaderIncludedPacket implements // NetworkEndpoint.WriteHeaderIncludedPacket. -func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error { // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } // HandlePacket implements NetworkEndpoint.HandlePacket. @@ -99,11 +99,20 @@ func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.invalidatedRtr = rtr } -var _ NetworkProtocol = (*testIPv6Protocol)(nil) +// Stats implements NetworkEndpoint. +func (*testIPv6Endpoint) Stats() NetworkEndpointStats { + return &testIPv6EndpointStats{} +} + +var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil) + +type testIPv6EndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) -// An IPv6 NetworkProtocol that supports the bare minimum to make a stack -// believe it supports IPv6. -// // We use this instead of ipv6.protocol because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Protocol struct{} @@ -140,12 +149,12 @@ func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, } // SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { return nil } // Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { return nil } @@ -160,15 +169,13 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } -var _ LinkAddressResolver = (*testIPv6Protocol)(nil) - // LinkAddressProtocol implements LinkAddressResolver. func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 12d67409a..77926e289 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -174,10 +174,6 @@ type NUDHandler interface { // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) - - // HandleUpperLevelConfirmation processes an incoming upper-level protocol - // (e.g. TCP acknowledgements) reachability confirmation. - HandleUpperLevelConfirmation(addr tcpip.Address) } // NUDConfigurations is the NUD configurations for the netstack. This is used diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 7bca1373e..ebfd5eb45 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -65,8 +65,9 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // No NIC with ID 1 yet. config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + err := s.SetNUDConfigurations(1, config) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{}) } } @@ -90,8 +91,9 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + _, err := s.NUDConfigurations(nicID) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) } } @@ -117,8 +119,9 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { } config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + err := s.SetNUDConfigurations(nicID, config) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) } } diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 41529ffd5..1c651e216 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -28,108 +28,205 @@ const ( maxPendingPacketsPerResolution = 256 ) +// pendingPacketBuffer is a pending packet buffer. +// +// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use +// WritePackets so we can use a PacketBufferList everywhere. +type pendingPacketBuffer interface { + len() int +} + +func (*PacketBuffer) len() int { + return 1 +} + +func (p *PacketBufferList) len() int { + return p.Len() +} + type pendingPacket struct { - route *Route - proto tcpip.NetworkProtocolNumber - pkt *PacketBuffer + routeInfo RouteInfo + gso *GSO + proto tcpip.NetworkProtocolNumber + pkt pendingPacketBuffer } // packetsPendingLinkResolution is a queue of packets pending link resolution. // // Once link resolution completes successfully, the packets will be written. type packetsPendingLinkResolution struct { - sync.Mutex + nic *NIC - // The packets to send once the resolver completes. - packets map[<-chan struct{}][]pendingPacket + mu struct { + sync.Mutex - // FIFO of channels used to cancel the oldest goroutine waiting for - // link-address resolution. - cancelChans []chan struct{} -} + // The packets to send once the resolver completes. + // + // The link resolution channel is used as the key for this map. + packets map[<-chan struct{}][]pendingPacket -func (f *packetsPendingLinkResolution) init() { - f.Lock() - defer f.Unlock() - f.packets = make(map[<-chan struct{}][]pendingPacket) + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + // + // cancelChans holds the same channels that are used as keys to packets. + cancelChans []<-chan struct{} + } } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - f.Lock() - defer f.Unlock() +func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { + n := uint64(pkt.len()) + f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n) - packets, ok := f.packets[ch] - if len(packets) == maxPendingPacketsPerResolution { - p := packets[0] - packets[0] = pendingPacket{} - packets = packets[1:] - p.route.Stats().IP.OutgoingPacketErrors.Increment() - p.route.Release() + if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { + ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n) } +} - if l := len(packets); l >= maxPendingPacketsPerResolution { - panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) +func (f *packetsPendingLinkResolution) init(nic *NIC) { + f.mu.Lock() + defer f.mu.Unlock() + f.nic = nic + f.mu.packets = make(map[<-chan struct{}][]pendingPacket) +} + +// dequeue any pending packets associated with ch. +// +// If success is true, packets will be written and sent to the given remote link +// address. +func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) { + f.mu.Lock() + packets, ok := f.mu.packets[ch] + delete(f.mu.packets, ch) + + if ok { + for i, cancelChan := range f.mu.cancelChans { + if cancelChan == ch { + f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...) + break + } + } } - f.packets[ch] = append(packets, pendingPacket{ - route: r, - proto: proto, - pkt: pkt, - }) + f.mu.Unlock() if ok { - return + f.dequeuePackets(packets, linkAddr, success) } +} - // Wait for the link-address resolution to complete. - cancel := f.newCancelChannelLocked() - go func() { - cancelled := false - select { - case <-ch: - case <-cancel: - cancelled = true - } +// enqueue a packet to be sent once link resolution completes. +// +// If the maximum number of pending resolutions is reached, the packets +// associated with the oldest link resolution will be dequeued as if they failed +// link resolution. +func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + f.mu.Lock() + // Make sure we attempt resolution while holding f's lock so that we avoid + // a race where link resolution completes before we enqueue the packets. + // + // A @ T1: Call ResolvedFields (get link resolution channel) + // B @ T2: Complete link resolution, dequeue pending packets + // C @ T1: Enqueue packet that already completed link resolution (which will + // never dequeue) + // + // To make sure B does not interleave with A and C, we make sure A and C are + // done while holding the lock. + routeInfo, ch, err := r.resolvedFields(nil) + switch err.(type) { + case nil: + // The route resolved immediately, so we don't need to wait for link + // resolution to send the packet. + f.mu.Unlock() + return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt) + case *tcpip.ErrWouldBlock: + // We need to wait for link resolution to complete. + default: + f.mu.Unlock() + return 0, err + } - f.Lock() - packets, ok := f.packets[ch] - delete(f.packets, ch) - f.Unlock() + defer f.mu.Unlock() - if !ok { - panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) - } + packets, ok := f.mu.packets[ch] + packets = append(packets, pendingPacket{ + routeInfo: routeInfo, + gso: gso, + proto: proto, + pkt: pkt, + }) - for _, p := range packets { - if cancelled || p.route.IsResolutionRequired() { - p.route.Stats().IP.OutgoingPacketErrors.Increment() + if len(packets) > maxPendingPacketsPerResolution { + f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt) + packets[0] = pendingPacket{} + packets = packets[1:] - if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { - linkResolvableEP.HandleLinkResolutionFailure(pkt) - } - } else { - p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt) - } - p.route.Release() + if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution { + panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution)) } - }() + } + + f.mu.packets[ch] = packets + + if ok { + return pkt.len(), nil + } + + cancelledPackets := f.newCancelChannelLocked(ch) + + if len(cancelledPackets) != 0 { + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as handing link resolution failures may be a costly operation. + go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */) + } + + return pkt.len(), nil } -// newCancelChannel creates a channel that can cancel a pending forwarding -// activity. The oldest channel is closed if the number of open channels would -// exceed maxPendingResolutions. -func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { - if len(f.cancelChans) == maxPendingResolutions { - ch := f.cancelChans[0] - f.cancelChans[0] = nil - f.cancelChans = f.cancelChans[1:] - close(ch) +// newCancelChannelLocked appends the link resolution channel to a FIFO. If the +// maximum number of pending resolutions is reached, the oldest channel will be +// removed and its associated pending packets will be returned. +func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket { + f.mu.cancelChans = append(f.mu.cancelChans, newCH) + if len(f.mu.cancelChans) <= maxPendingResolutions { + return nil } - if l := len(f.cancelChans); l >= maxPendingResolutions { + + ch := f.mu.cancelChans[0] + f.mu.cancelChans[0] = nil + f.mu.cancelChans = f.mu.cancelChans[1:] + if l := len(f.mu.cancelChans); l > maxPendingResolutions { panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) } - ch := make(chan struct{}) - f.cancelChans = append(f.cancelChans, ch) - return ch + packets, ok := f.mu.packets[ch] + if !ok { + panic("must have a packet queue for an uncancelled channel") + } + delete(f.mu.packets, ch) + + return packets +} + +func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) { + for _, p := range packets { + if success { + p.routeInfo.RemoteLinkAddress = linkAddr + _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + } else { + f.incrementOutgoingPacketErrors(p.proto, p.pkt) + + if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok { + switch pkt := p.pkt.(type) { + case *PacketBuffer: + linkResolvableEP.HandleLinkResolutionFailure(pkt) + case *PacketBufferList: + for pb := pkt.Front(); pb != nil; pb = pb.Next() { + linkResolvableEP.HandleLinkResolutionFailure(pb) + } + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } + } + } + } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 68c113b6a..510da8689 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -172,10 +172,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -184,7 +184,7 @@ type TransportProtocol interface { // ParsePorts returns the source and destination ports stored in a // packet of this protocol. - ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) + ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this // protocol that don't match any existing endpoint. For example, @@ -197,12 +197,12 @@ type TransportProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -289,10 +289,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - JoinGroup(group tcpip.Address) *tcpip.Error + JoinGroup(group tcpip.Address) tcpip.Error // LeaveGroup attempts to leave the specified group. - LeaveGroup(group tcpip.Address) *tcpip.Error + LeaveGroup(group tcpip.Address) tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -440,17 +440,17 @@ func (k AddressKind) IsPermanent() bool { type AddressableEndpoint interface { // AddAndAcquirePermanentAddress adds the passed permanent address. // - // Returns tcpip.ErrDuplicateAddress if the address exists. + // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. // - // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // Returns *tcpip.ErrBadLocalAddress if the endpoint does not have the passed // permanent address. - RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + RemovePermanentAddress(addr tcpip.Address) tcpip.Error // MainAddress returns the endpoint's primary permanent address. MainAddress() tcpip.AddressWithPrefix @@ -512,14 +512,14 @@ type NetworkInterface interface { Promiscuous() bool // WritePacketToRemote writes the packet to the given remote link address. - WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePacket writes a packet with the given protocol through the given // route. // // WritePacket takes ownership of the packet buffer. The packet buffer's // network and transport header must be set. - WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. @@ -529,7 +529,7 @@ type NetworkInterface interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // LinkResolvableNetworkEndpoint handles link resolution events. @@ -547,8 +547,8 @@ type NetworkEndpoint interface { // Must only be called when the stack is in a state that allows the endpoint // to send and receive packets. // - // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. - Enable() *tcpip.Error + // Returns *tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() tcpip.Error // Enabled returns true if the endpoint is enabled. Enabled() bool @@ -574,16 +574,16 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and // underlying packets. - WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. - WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. @@ -597,6 +597,26 @@ type NetworkEndpoint interface { // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for // this endpoint. NetworkProtocolNumber() tcpip.NetworkProtocolNumber + + // Stats returns a reference to the network endpoint stats. + Stats() NetworkEndpointStats +} + +// NetworkEndpointStats is the interface implemented by each network endpoint +// stats struct. +type NetworkEndpointStats interface { + // IsNetworkEndpointStats is an empty method to implement the + // NetworkEndpointStats marker interface. + IsNetworkEndpointStats() +} + +// IPNetworkEndpointStats is a NetworkEndpointStats that tracks IP-related +// statistics. +type IPNetworkEndpointStats interface { + NetworkEndpointStats + + // IPStats returns the IP statistics of a network endpoint. + IPStats() *tcpip.IPStats } // ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. @@ -634,12 +654,12 @@ type NetworkProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error + Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -776,7 +796,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. @@ -786,7 +806,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -801,7 +821,7 @@ type InjectableLinkEndpoint interface { // link. // // dest is used by endpoints with multiple raw destinations. - InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } // A LinkAddressResolver is an extension to a NetworkProtocol that @@ -813,7 +833,7 @@ type LinkAddressResolver interface { // // The request is sent from the passed network interface. If the interface // local address is unspecified, any interface local address may be used. - LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -829,12 +849,8 @@ type LinkAddressResolver interface { // A LinkAddressCache caches link addresses. type LinkAddressCache interface { - // CheckLocalAddress determines if the given local address exists, and if it - // does not exist. - CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID - // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) } // RawFactory produces endpoints for writing various types of raw packets. @@ -842,11 +858,11 @@ type RawFactory interface { // NewUnassociatedEndpoint produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can // be used to write arbitrary packets that include the network header. - NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewPacketEndpoint produces endpoints for reading and writing packets // that include network and (when cooked is false) link layer headers. - NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) } // GSOType is the type of GSO segments. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1ff7b3a37..4ae0f2a1a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -86,12 +86,21 @@ type RouteInfo struct { RemoteLinkAddress tcpip.LinkAddress } -// Fields returns a RouteInfo with all of r's exported fields. This allows -// callers to store the route's fields without retaining a reference to it. +// Fields returns a RouteInfo with all of the known values for the route's +// fields. +// +// If any fields are unknown (e.g. remote link address when it is waiting for +// link address resolution), they will be unset. func (r *Route) Fields() RouteInfo { + r.mu.RLock() + defer r.mu.RUnlock() + return r.fieldsLocked() +} + +func (r *Route) fieldsLocked() RouteInfo { return RouteInfo{ routeInfo: r.routeInfo, - RemoteLinkAddress: r.RemoteLinkAddress(), + RemoteLinkAddress: r.mu.remoteLinkAddress, } } @@ -306,32 +315,43 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. +// ResolvedFieldsResult is the result of a route resolution attempt. +type ResolvedFieldsResult struct { + RouteInfo RouteInfo + Success bool +} + +// ResolvedFields attempts to resolve the remote link address if it is not +// known. // -// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. -// waiting for ARP reply). If address resolution is required, a notification -// channel is also returned for the caller to block on. The channel is closed -// once address resolution is complete (successful or not). If a callback is -// provided, it will be called when address resolution is complete, regardless +// If a callback is provided, it will be called before ResolvedFields returns +// when address resolution is not required. If address resolution is required, +// the callback will be called once address resolution is complete, regardless // of success or failure. -func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { - r.mu.Lock() - - if !r.isResolutionRequiredRLocked() { - // Nothing to do if there is no cache (which does the resolution on cache miss) or - // link address is already known. - r.mu.Unlock() - return nil, nil - } - - // Increment the route's reference count because finishResolution retains a - // reference to the route and releases it when called. - r.acquireLocked() - r.mu.Unlock() +// +// Note, the route will not cache the remote link address when address +// resolution completes. +func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) tcpip.Error { + _, _, err := r.resolvedFields(afterResolve) + return err +} - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress +// resolvedFields is like ResolvedFields but also returns a notification channel +// when address resolution is required. This channel will become readable once +// address resolution is complete. +// +// The route's fields will also be returned, regardless of whether address +// resolution is required or not. +func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, tcpip.Error) { + r.mu.RLock() + fields := r.fieldsLocked() + resolutionRequired := r.isResolutionRequiredRLocked() + r.mu.RUnlock() + if !resolutionRequired { + if afterResolve != nil { + afterResolve(ResolvedFieldsResult{RouteInfo: fields, Success: true}) + } + return fields, nil, nil } // If specified, the local address used for link address resolution must be an @@ -341,18 +361,27 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } - finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { - if ok { - r.ResolveWith(linkAddress) - } + afterResolveFields := fields + linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { if afterResolve != nil { - afterResolve() + if r.Success { + afterResolveFields.RemoteLinkAddress = r.LinkAddress + } + + afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Success: r.Success}) } - r.Release() + }) + if err == nil { + fields.RemoteLinkAddress = linkAddr } + return fields, ch, err +} - _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) - return ch, err +func (r *Route) nextHop() tcpip.Address { + if len(r.NextHop) == 0 { + return r.RemoteAddress + } + return r.NextHop } // local returns true if the route is a local route. @@ -371,11 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { - return false - } - - return r.linkRes != nil + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -404,9 +429,9 @@ func (r *Route) isValidForOutgoingRLocked() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) @@ -414,9 +439,9 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf // WritePackets writes a list of n packets through the given route and returns // the number of packets written. -func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { if !r.isValidForOutgoing() { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) @@ -424,9 +449,9 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) @@ -496,3 +521,12 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } + +// ConfirmReachable informs the network/link layer that the neighbour used for +// the route is reachable. +// +// "Reachable" is defined as having full-duplex communication between the +// local and remote ends of the route. +func (r *Route) ConfirmReachable() { + r.outgoingNIC.confirmReachable(r.nextHop()) +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c0aec61a6..119c4c505 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -76,12 +76,16 @@ type TCPCubicState struct { // TCPRACKState is used to hold a copy of the internal RACK state when the // TCPProbeFunc is invoked. type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool + DSACKSeen bool + ReoWnd time.Duration + ReoWndIncr uint8 + ReoWndPersist int8 + RTTSeq seqnum.Value } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -382,8 +386,6 @@ type Stack struct { stats tcpip.Stats - linkAddrCache *linkAddrCache - mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -446,7 +448,7 @@ type Stack struct { // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. - sendBufferSize SendBufferSizeOption + sendBufferSize tcpip.SendBufferSizeOption // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. @@ -554,7 +556,7 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { netProto := t.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: @@ -572,11 +574,11 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl switch len(t.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } case header.IPv6AddressSize: if len(addr.Addr) == header.IPv4AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + return tcpip.FullAddress{}, 0, &tcpip.ErrNetworkUnreachable{} } } @@ -584,10 +586,10 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl case netProto == t.NetProto: case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: if v6only { - return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + return tcpip.FullAddress{}, 0, &tcpip.ErrNoRoute{} } default: - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } return addr, netProto, nil @@ -636,7 +638,6 @@ func New(opts Options) *Stack { linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, stats: opts.Stats.FillIn(), @@ -649,7 +650,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -701,10 +702,10 @@ func (s *Stack) UniqueID() uint64 { // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.SetOption(option) } @@ -718,10 +719,10 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op // if err != nil { // ... // } -func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.Option(option) } @@ -730,10 +731,10 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.SetOption(option) } @@ -745,10 +746,10 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.Option(option) } @@ -781,15 +782,15 @@ func (s *Stack) Stats() tcpip.Stats { // SetForwarding enables or disables packet forwarding between NICs for the // passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } forwardingProtocol.SetForwarding(enable) @@ -852,10 +853,10 @@ func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { } // NewEndpoint creates a new transport layer endpoint of the given protocol. -func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewEndpoint(network, waiterQueue) @@ -864,9 +865,9 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // NewRawEndpoint creates a new raw transport layer endpoint of the given // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. -func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } if !associated { @@ -875,7 +876,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewRawEndpoint(network, waiterQueue) @@ -883,9 +884,9 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // NewPacketEndpoint creates a new packet endpoint listening for the given // netProto. -func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) @@ -916,20 +917,20 @@ type NICOptions struct { // NICs can be configured. // // LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() // Make sure id is unique. if _, ok := s.nics[id]; ok { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } // Make sure name is unique, unless unnamed. if opts.Name != "" { for _, n := range s.nics { if n.Name() == opts.Name { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } } } @@ -945,7 +946,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls // LinkEndpoint.Attach to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } @@ -963,26 +964,26 @@ func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. -func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) EnableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.enable() } // DisableNIC disables the given NIC. -func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) DisableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.disable() @@ -1003,7 +1004,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // RemoveNIC removes NIC and all related routes from the network stack. -func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) RemoveNIC(id tcpip.NICID) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1013,10 +1014,10 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { // removeNICLocked removes NIC and all related routes from the network stack. // // s.mu must be locked. -func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { +func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } delete(s.nics, id) @@ -1050,6 +1051,9 @@ type NICInfo struct { Stats NICStats + // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. + NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats + // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext @@ -1081,6 +1085,12 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } + + netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats) + for proto, netEP := range nic.networkEndpoints { + netStats[proto] = netEP.Stats() + } + nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), @@ -1088,6 +1098,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.LinkEndpoint.MTU(), Stats: nic.stats, + NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), } @@ -1111,13 +1122,13 @@ type NICStateFlags struct { } // AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } // AddAddressWithPrefix is the same as AddAddress, but allows you to specify // the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { ap := tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: addr, @@ -1127,16 +1138,16 @@ func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProto // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) } // AddAddressWithOptions is the same as AddAddress, but allows you to specify // whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { netProto, ok := s.networkProtocols[protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ Protocol: protocol, @@ -1149,13 +1160,13 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt // AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows // you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.addAddress(protocolAddress, peb) @@ -1163,7 +1174,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc // RemoveAddress removes an existing network-layer address from the specified // NIC. -func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -1171,7 +1182,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { return nic.removeAddress(addr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // AllAddresses returns a map of NICIDs to their protocol addresses (primary @@ -1189,19 +1200,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { // GetMainNICAddress returns the first non-deprecated primary address and prefix // for the given NIC and protocol. If no non-deprecated primary address exists, -// a deprecated primary address and prefix will be returned. Returns an error if +// a deprecated primary address and prefix will be returned. Returns false if // the NIC doesn't exist and an empty value if the NIC doesn't have a primary // address for the given protocol. -func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + return tcpip.AddressWithPrefix{}, false } - return nic.primaryAddress(protocol), nil + return nic.primaryAddress(protocol), true } func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { @@ -1301,7 +1312,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1337,9 +1348,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1405,7 +1416,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if id == 0 { @@ -1425,12 +1436,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if header.IsV6LoopbackAddress(remoteAddr) { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1476,13 +1487,13 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto } // SetPromiscuousMode enables or disables promiscuous mode in the given NIC. -func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setPromiscuousMode(enable) @@ -1492,13 +1503,13 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error // SetSpoofing enables or disables address spoofing in the given NIC, allowing // endpoints to bind to any address in the NIC. -func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setSpoofing(enable) @@ -1506,17 +1517,27 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { return nil } -// AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.add(fullAddr, linkAddr) - // TODO: provide a way for a transport endpoint to receive a signal - // that AddLinkAddress for a particular address has been called. +// AddLinkAddress adds a link address for the neighbor on the specified NIC. +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[nicID] + if !ok { + return &tcpip.ErrUnknownNICID{} + } + + nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) + return nil +} + +// LinkResolutionResult is the result of a link address resolution attempt. +type LinkResolutionResult struct { + LinkAddress tcpip.LinkAddress + Success bool } -// GetLinkAddress finds the link address corresponding to a neighbor's address. -// -// Returns a link address for the remote address, if readily available. +// GetLinkAddress finds the link address corresponding to a network address. // // Returns ErrNotSupported if the stack is not configured with a link address // resolver for the specified network protocol. @@ -1525,53 +1546,56 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t // with a notification channel for the caller to block on. Triggers address // resolution asynchronously. // -// If onResolve is provided, it will be called either immediately, if -// resolution is not required, or when address resolution is complete, with -// the resolved link address and whether resolution succeeded. After any -// callbacks have been called, the returned notification channel is closed. +// onResolve will be called either immediately, if resolution is not required, +// or when address resolution is complete, with the resolved link address and +// whether resolution succeeded. // // If specified, the local address must be an address local to the interface // the neighbor cache belongs to. The local address is the source address of // a packet prompting NUD/link address resolution. -// -// TODO(gvisor.dev/issue/5151): Don't return the link address. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return "", nil, tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } linkRes, ok := s.linkAddrResolvers[protocol] if !ok { - return "", nil, tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} + } + + if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil } - return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + return err } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } return nic.neighbors() } // AddStaticNeighbor statically associates an IP address to a MAC address. -func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.addStaticNeighbor(addr, linkAddr) @@ -1580,26 +1604,26 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.removeNeighbor(addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.clearNeighbors() @@ -1609,25 +1633,25 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CheckRegisterTransportEndpoint checks if an endpoint can be registered with // the stack transport dispatcher. -func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) UnregisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} s.cleanupEndpointsMu.Unlock() @@ -1652,13 +1676,13 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. -func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. -func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { +func (s *Stack) UnregisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { s.demux.unregisterRawEndpoint(netProto, transProto, ep) } @@ -1762,7 +1786,7 @@ func (s *Stack) Resume() { // RegisterPacketEndpoint registers ep with the stack, causing it to receive // all traffic of the specified netProto on the given NIC. If nicID is 0, it // receives traffic from every NIC. -func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1781,7 +1805,7 @@ func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.Network // Capture on a specific device. nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } if err := nic.registerPacketEndpoint(netProto, ep); err != nil { return err @@ -1819,12 +1843,12 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip // WritePacketToRemote writes a payload on the specified NIC using the provided // network protocol and remote link address. -func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } pkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: int(nic.MaxHeaderLength()), @@ -1889,37 +1913,37 @@ func (s *Stack) RemoveTCPProbe() { } // JoinGroup joins the given multicast group on the given NIC. -func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.joinGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // LeaveGroup leaves the given multicast group on the given NIC. -func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.leaveGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // IsInGroup returns true if the NIC with ID nicID has joined the multicast // group multicastAddr. -func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.isInGroup(multicastAddr), nil } - return false, tcpip.ErrUnknownNICID + return false, &tcpip.ErrUnknownNICID{} } // IPTables returns the stack's iptables. @@ -1959,26 +1983,26 @@ func (s *Stack) AllowICMPMessage() bool { // GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol // number installed on the specified NIC. -func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() nic, ok := s.nics[nicID] if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return NUDConfigurations{}, tcpip.ErrUnknownNICID + return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } return nic.nudConfigs() @@ -1988,13 +2012,13 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.setNUDConfigs(c) @@ -2036,7 +2060,7 @@ func generateRandInt64() int64 { } // FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -2048,7 +2072,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres addressEndpoint.DecRef() return nic.getNetworkEndpoint(netProto), nil } - return nil, tcpip.ErrBadAddress + return nil, &tcpip.ErrBadAddress{} } // FindNICNameFromID returns the name of the NIC for the given NICID. diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 0b093e6c5..8d9b20b7e 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -14,7 +14,9 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // MinBufferSize is the smallest size of a receive or send buffer. @@ -29,14 +31,6 @@ const ( DefaultMaxBufferSize = 4 << 20 // 4 MiB ) -// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to -// get/set the default, min and max send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - // ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max receive buffer sizes. type ReceiveBufferSizeOption struct { @@ -46,17 +40,17 @@ type ReceiveBufferSizeOption struct { } // SetOption allows setting stack wide options. -func (s *Stack) SetOption(option interface{}) *tcpip.Error { +func (s *Stack) SetOption(option interface{}) tcpip.Error { switch v := option.(type) { - case SendBufferSizeOption: + case tcpip.SendBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -68,11 +62,11 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -81,14 +75,14 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option allows retrieving stack wide options. -func (s *Stack) Option(option interface{}) *tcpip.Error { +func (s *Stack) Option(option interface{}) tcpip.Error { switch v := option.(type) { - case *SendBufferSizeOption: + case *tcpip.SendBufferSizeOption: s.mu.RLock() *v = s.sendBufferSize s.mu.RUnlock() @@ -101,6 +95,6 @@ func (s *Stack) Option(option interface{}) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7e935ddff..41f95811f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,6 +60,15 @@ const ( protocolNumberOffset = 2 ) +func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { + if addr, ok := s.GetMainNICAddress(nicID, proto); !ok { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto) + } else if addr != want { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want) + } + return nil +} + // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and // received packets; the counts of all endpoints are aggregated in the protocol // descriptor. @@ -81,7 +90,7 @@ type fakeNetworkEndpoint struct { dispatcher stack.TransportDispatcher } -func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { +func (f *fakeNetworkEndpoint) Enable() tcpip.Error { f.mu.Lock() defer f.mu.Unlock() f.mu.enabled = true @@ -145,7 +154,7 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { +func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { return 0 } @@ -153,7 +162,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -176,18 +185,30 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func (f *fakeNetworkEndpoint) Close() { f.AddressableEndpointState.Cleanup() } +// Stats implements NetworkEndpoint. +func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats { + return &fakeNetworkEndpointStats{} +} + +var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil) + +type fakeNetworkEndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} + // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. @@ -202,15 +223,15 @@ type fakeNetworkProtocol struct { } } -func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { +func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fakeNetNumber } -func (f *fakeNetworkProtocol) MinimumPacketSize() int { +func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (f *fakeNetworkProtocol) DefaultPrefixLen() int { +func (*fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } @@ -232,23 +253,23 @@ func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.Li return e } -func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: f.defaultTTL = uint8(*v) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -397,7 +418,7 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return err @@ -406,7 +427,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r *stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -435,14 +456,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -579,8 +600,8 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } } @@ -628,8 +649,9 @@ func TestDisableUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.DisableNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -687,8 +709,9 @@ func TestRemoveUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.RemoveNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -731,8 +754,8 @@ func TestRemoveNIC(t *testing.T) { func TestRouteWithDownNIC(t *testing.T) { tests := []struct { name string - downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error }{ { name: "Disabled NIC", @@ -890,15 +913,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -911,7 +934,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1036,11 +1059,12 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) } } @@ -1087,12 +1111,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + { + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) + } } } @@ -1186,7 +1213,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1214,7 +1241,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1253,8 +1280,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1290,7 +1317,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1333,8 +1360,8 @@ func TestPromiscuousMode(t *testing.T) { // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } // Set promiscuous mode to false, then check that packet can't be @@ -1540,7 +1567,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1590,8 +1617,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} @@ -1610,8 +1640,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } } @@ -1753,9 +1786,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { anyAddr = header.IPv6Any } - want := tcpip.ErrNetworkUnreachable + var want tcpip.Error = &tcpip.ErrNetworkUnreachable{} if tc.routeNeeded { - want = tcpip.ErrNoRoute + want = &tcpip.ErrNoRoute{} } // If there is no endpoint, it won't work. @@ -1769,8 +1802,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { // Route table is empty but we need a route, this should cause an error. - if err != tcpip.ErrNoRoute { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{}) } } else { if err != nil { @@ -1861,20 +1894,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) + gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if len(primaryAddrAdded) == 0 { // No primary addresses present. if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr) } } else { // At least one primary address was added, verify the returned // address is in the list of primary addresses we added. if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded) } } }) @@ -1915,12 +1948,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { - t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { @@ -1928,12 +1957,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get no address after removal. - gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2102,7 +2127,7 @@ func TestCreateNICWithOptions(t *testing.T) { type callArgsAndExpect struct { nicID tcpip.NICID opts stack.NICOptions - err *tcpip.Error + err tcpip.Error } tests := []struct { @@ -2120,7 +2145,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{Name: "eth2"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2135,7 +2160,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(2), opts: stack.NICOptions{Name: "lo"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2165,7 +2190,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2474,12 +2499,12 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } } - gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + // Check that we get no address after removal. + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } - if gotMainAddr != expectedMainAddr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { + t.Fatal(err) } }) } @@ -2525,12 +2550,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2561,12 +2582,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } linkLocalAddr := header.LinkLocalAddr(linkAddr1) @@ -2584,12 +2601,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { + t.Fatal(err) } } @@ -2621,17 +2634,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { t.Fatal("AddAddressWithOptions failed:", err) } - addr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + addr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if pi == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } { @@ -2710,18 +2723,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.RemoveAddress(1, "\x03"); err != nil { t.Fatalf("RemoveAddress failed: %v", err) } - addr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatalf("s.GetMainNICAddress failed: %v", err) + addr, ok = s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else { if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } } }) @@ -3247,12 +3259,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should be tentative so it should not be a main address. - got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Enabling the NIC should start DAD for the address. @@ -3264,12 +3272,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -3284,12 +3288,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } // Enabling the NIC again should be a no-op. @@ -3299,12 +3299,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -3313,14 +3309,14 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { testCases := []struct { name string rs stack.ReceiveBufferSizeOption - err *tcpip.Error + err tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, @@ -3352,31 +3348,32 @@ func TestStackSendBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - ss stack.SendBufferSizeOption - err *tcpip.Error + ss tcpip.SendBufferSizeOption + err tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations - {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{}) defer s.Close() - if err := s.SetOption(tc.ss); err != tc.err { - t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + err := s.SetOption(tc.ss) + if diff := cmp.Diff(tc.err, err); diff != "" { + t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff) } - var ss stack.SendBufferSizeOption if tc.err == nil { + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err != nil { t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) } @@ -3778,20 +3775,16 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } // Should still get the address when the NIC is diabled. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("DisableNIC(%d): %s", nicID, err) } - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -3939,7 +3932,7 @@ func TestFindRouteWithForwarding(t *testing.T) { addrNIC tcpip.NICID localAddr tcpip.Address - findRouteErr *tcpip.Error + findRouteErr tcpip.Error dependentOnForwarding bool }{ { @@ -3948,7 +3941,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3957,7 +3950,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3966,7 +3959,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4002,7 +3995,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4011,7 +4004,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4035,7 +4028,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4051,7 +4044,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4059,7 +4052,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4067,7 +4060,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4075,7 +4068,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4107,7 +4100,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4115,7 +4108,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4123,7 +4116,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4131,7 +4124,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4186,8 +4179,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if r != nil { defer r.Release() } - if err != test.findRouteErr { - t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + if diff := cmp.Diff(test.findRouteErr, err); diff != "" { + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { @@ -4234,8 +4227,11 @@ func TestFindRouteWithForwarding(t *testing.T) { if err := s.SetForwarding(test.netCfg.proto, false); err != nil { t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) } - if err := send(r, data); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + { + err := send(r, data) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{}) + } } if n := ep1.Drain(); n != 0 { t.Errorf("got %d unexpected packets from ep1", n) @@ -4297,8 +4293,9 @@ func TestWritePacketToRemote(t *testing.T) { } t.Run("InvalidNICID", func(t *testing.T) { - if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()) + if _, ok := err.(*tcpip.ErrUnknownDevice); !ok { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{}) } pkt, ok := e.Read() if got, want := ok, false; got != want { @@ -4372,10 +4369,64 @@ func TestGetLinkAddressErrors(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID) + { + err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{}) + } + } + { + err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{}) + } + } +} + +func TestStaticGetLinkAddress(t *testing.T) { + const ( + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + }) + if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4", + proto: ipv4.ProtocolNumber, + addr: header.IPv4Broadcast, + expectedLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "IPv6", + proto: ipv6.ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, } - if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch := make(chan stack.LinkResolutionResult, 1) + if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != nil { + t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) + } + + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 07b2818d2..26eceb804 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -205,7 +205,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -222,7 +222,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -294,7 +294,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for i, n := range netProtos { if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) @@ -306,7 +306,7 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // checkEndpoint checks if an endpoint can be registered with the dispatcher. -func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for _, n := range netProtos { if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { return err @@ -403,7 +403,7 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() @@ -412,7 +412,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -422,7 +422,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } -func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error { ep.mu.RLock() defer ep.mu.RUnlock() @@ -431,7 +431,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -456,7 +456,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. flags.LoadBalanced = false @@ -464,7 +464,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.Lock() @@ -482,7 +482,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } -func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. flags.LoadBalanced = false @@ -490,7 +490,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.RLock() @@ -649,10 +649,10 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw // endpoint. -func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.mu.Lock() diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 57e1f8354..10cbbe589 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -175,9 +175,9 @@ func TestTransportDemuxerRegister(t *testing.T) { for _, test := range []struct { name string proto tcpip.NetworkProtocolNumber - want *tcpip.Error + want tcpip.Error }{ - {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, {"success", ipv4.ProtocolNumber, nil}, } { t.Run(test.name, func(t *testing.T) { @@ -194,7 +194,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) @@ -294,7 +294,7 @@ func TestBindToDeviceDistribution(t *testing.T) { defer wq.EventUnregister(&we) defer close(ch) - var err *tcpip.Error + var err tcpip.Error ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9d39533a1..cf5de747b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "io" "testing" @@ -67,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} - ep.ops.InitHandler(ep) +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) return ep } @@ -86,19 +87,20 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { return tcpip.ReadResult{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, Data: buffer.View(v).ToVectorisedView(), @@ -112,42 +114,42 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions } // SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return -1, tcpip.ErrUnknownProtocolOption +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { + return -1, &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeTransportEndpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error { f.peerAddr = addr.Addr // Find the route. r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { r.Release() return err @@ -162,22 +164,22 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error { return nil } -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { +func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return nil } func (*fakeTransportEndpoint) Reset() { } -func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { +func (*fakeTransportEndpoint) Listen(int) tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -186,9 +188,8 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai return a, nil, nil } -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error { if err := f.proto.stack.RegisterTransportEndpoint( - a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, @@ -202,11 +203,11 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { return nil } -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } @@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress, route: route, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } @@ -251,7 +252,7 @@ func (*fakeTransportEndpoint) Resume(*stack.Stack) {} func (*fakeTransportEndpoint) Wait() {} -func (*fakeTransportEndpoint) LastError() *tcpip.Error { +func (*fakeTransportEndpoint) LastError() tcpip.Error { return nil } @@ -279,19 +280,19 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack), nil } -func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrUnknownProtocol +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return nil, &tcpip.ErrUnknownProtocol{} } func (*fakeTransportProtocol) MinimumPacketSize() int { return fakeTransHeaderLen } -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) { return 0, 0, nil } @@ -299,23 +300,23 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndp return stack.UnknownDestinationPacketHandled } -func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: f.opts.good = bool(*v) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -520,8 +521,10 @@ func TestTransportSend(t *testing.T) { } // Create buffer that will hold the payload. - view := buffer.NewView(30) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + b := make([]byte, 30) + var r bytes.Reader + r.Reset(b) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 56aac093c..c500a0d1c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -29,6 +29,7 @@ package tcpip import ( + "bytes" "errors" "fmt" "io" @@ -46,141 +47,6 @@ import ( // Using header.IPv4AddressSize would cause an import cycle. const ipv4AddressSize = 4 -// Error represents an error in the netstack error space. Using a special type -// ensures that errors outside of this space are not accidentally introduced. -// -// All errors must have unique msg strings. -// -// +stateify savable -type Error struct { - msg string - - ignoreStats bool -} - -// String implements fmt.Stringer.String. -func (e *Error) String() string { - if e == nil { - return "<nil>" - } - return e.msg -} - -// IgnoreStats indicates whether this error type should be included in failure -// counts in tcpip.Stats structs. -func (e *Error) IgnoreStats() bool { - return e.ignoreStats -} - -// Errors that can be returned by the network stack. -var ( - ErrUnknownProtocol = &Error{msg: "unknown protocol"} - ErrUnknownNICID = &Error{msg: "unknown nic id"} - ErrUnknownDevice = &Error{msg: "unknown device"} - ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} - ErrDuplicateNICID = &Error{msg: "duplicate nic id"} - ErrDuplicateAddress = &Error{msg: "duplicate address"} - ErrNoRoute = &Error{msg: "no route"} - ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} - ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} - ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} - ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} - ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} - ErrNoPortAvailable = &Error{msg: "no ports are available"} - ErrPortInUse = &Error{msg: "port is in use"} - ErrBadLocalAddress = &Error{msg: "bad local address"} - ErrClosedForSend = &Error{msg: "endpoint is closed for send"} - ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} - ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} - ErrConnectionRefused = &Error{msg: "connection was refused"} - ErrTimeout = &Error{msg: "operation timed out"} - ErrAborted = &Error{msg: "operation aborted"} - ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} - ErrDestinationRequired = &Error{msg: "destination address is required"} - ErrNotSupported = &Error{msg: "operation not supported"} - ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} - ErrNotConnected = &Error{msg: "endpoint not connected"} - ErrConnectionReset = &Error{msg: "connection reset by peer"} - ErrConnectionAborted = &Error{msg: "connection aborted"} - ErrNoSuchFile = &Error{msg: "no such file"} - ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} - ErrBadAddress = &Error{msg: "bad address"} - ErrNetworkUnreachable = &Error{msg: "network is unreachable"} - ErrMessageTooLong = &Error{msg: "message too long"} - ErrNoBufferSpace = &Error{msg: "no buffer space available"} - ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} - ErrNotPermitted = &Error{msg: "operation not permitted"} - ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} - ErrMalformedHeader = &Error{msg: "header is malformed"} - ErrBadBuffer = &Error{msg: "bad buffer"} -) - -var messageToError map[string]*Error - -var populate sync.Once - -// StringToError converts an error message to the error. -func StringToError(s string) *Error { - populate.Do(func() { - var errors = []*Error{ - ErrUnknownProtocol, - ErrUnknownNICID, - ErrUnknownDevice, - ErrUnknownProtocolOption, - ErrDuplicateNICID, - ErrDuplicateAddress, - ErrNoRoute, - ErrBadLinkEndpoint, - ErrAlreadyBound, - ErrInvalidEndpointState, - ErrAlreadyConnecting, - ErrAlreadyConnected, - ErrNoPortAvailable, - ErrPortInUse, - ErrBadLocalAddress, - ErrClosedForSend, - ErrClosedForReceive, - ErrWouldBlock, - ErrConnectionRefused, - ErrTimeout, - ErrAborted, - ErrConnectStarted, - ErrDestinationRequired, - ErrNotSupported, - ErrQueueSizeNotSupported, - ErrNotConnected, - ErrConnectionReset, - ErrConnectionAborted, - ErrNoSuchFile, - ErrInvalidOptionValue, - ErrBadAddress, - ErrNetworkUnreachable, - ErrMessageTooLong, - ErrNoBufferSpace, - ErrBroadcastDisabled, - ErrNotPermitted, - ErrAddressFamilyNotSupported, - ErrMalformedHeader, - ErrBadBuffer, - } - - messageToError = make(map[string]*Error) - for _, e := range errors { - if messageToError[e.String()] != nil { - panic("tcpip errors with duplicated message: " + e.String()) - } - messageToError[e.String()] = e - } - }) - - e, ok := messageToError[s] - if !ok { - panic("unknown error message: " + s) - } - - return e -} - // Errors related to Subnet var ( errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") @@ -194,7 +60,7 @@ type ErrSaveRejection struct { } // Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { +func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } @@ -471,30 +337,15 @@ type FullAddress struct { // This interface allows the endpoint to request the amount of data it needs // based on internal buffers without exposing them. type Payloader interface { - // FullPayload returns all available bytes. - FullPayload() ([]byte, *Error) + io.Reader - // Payload returns a slice containing at most size bytes. - Payload(size int) ([]byte, *Error) + // Len returns the number of bytes of the unread portion of the + // Reader. + Len() int } -// SlicePayload implements Payloader for slices. -// -// This is typically used for tests. -type SlicePayload []byte - -// FullPayload implements Payloader.FullPayload. -func (s SlicePayload) FullPayload() ([]byte, *Error) { - return s, nil -} - -// Payload implements Payloader.Payload. -func (s SlicePayload) Payload(size int) ([]byte, *Error) { - if size > len(s) { - size = len(s) - } - return s[:size], nil -} +var _ Payloader = (*bytes.Buffer)(nil) +var _ Payloader = (*bytes.Reader)(nil) var _ io.Writer = (*SliceWriter)(nil) @@ -647,7 +498,7 @@ type Endpoint interface { // If non-zero number of bytes are successfully read and written to dst, err // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer // should be returned. - Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error) + Read(io.Writer, ReadOptions) (ReadResult, Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. @@ -662,7 +513,7 @@ type Endpoint interface { // stream (TCP) Endpoints may return partial writes, and even then only // in the case where writing additional data would block. Other Endpoints // will either write the entire message or return an error. - Write(Payloader, WriteOptions) (int64, *Error) + Write(Payloader, WriteOptions) (int64, Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -676,21 +527,21 @@ type Endpoint interface { // connected returns nil. Calling connect again results in ErrAlreadyConnected. // Anything else -- the attempt to connect failed. // - // If address.Addr is empty, this means that Enpoint has to be + // If address.Addr is empty, this means that Endpoint has to be // disconnected if this is supported, otherwise // ErrAddressFamilyNotSupported must be returned. - Connect(address FullAddress) *Error + Connect(address FullAddress) Error // Disconnect disconnects the endpoint from its peer. - Disconnect() *Error + Disconnect() Error // Shutdown closes the read and/or write end of the endpoint connection // to its peer. - Shutdown(flags ShutdownFlags) *Error + Shutdown(flags ShutdownFlags) Error // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. - Listen(backlog int) *Error + Listen(backlog int) Error // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. This method does not @@ -700,36 +551,36 @@ type Endpoint interface { // // If peerAddr is not nil then it is populated with the peer address of the // returned endpoint. - Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. - Bind(address FullAddress) *Error + Bind(address FullAddress) Error // GetLocalAddress returns the address to which the endpoint is bound. - GetLocalAddress() (FullAddress, *Error) + GetLocalAddress() (FullAddress, Error) // GetRemoteAddress returns the address to which the endpoint is // connected. - GetRemoteAddress() (FullAddress, *Error) + GetRemoteAddress() (FullAddress, Error) // Readiness returns the current readiness of the endpoint. For example, // if waiter.EventIn is set, the endpoint is immediately readable. Readiness(mask waiter.EventMask) waiter.EventMask // SetSockOpt sets a socket option. - SetSockOpt(opt SettableSocketOption) *Error + SetSockOpt(opt SettableSocketOption) Error // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOptInt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) Error // GetSockOpt gets a socket option. - GetSockOpt(opt GettableSocketOption) *Error + GetSockOpt(opt GettableSocketOption) Error // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOptInt) (int, *Error) + GetSockOptInt(SockOptInt) (int, Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -752,7 +603,7 @@ type Endpoint interface { SetOwner(owner PacketOwner) // LastError clears and returns the last error reported by the endpoint. - LastError() *Error + LastError() Error // SocketOptions returns the structure which contains all the socket // level options. @@ -840,10 +691,6 @@ const ( // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption - // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to - // specify the send buffer size option. - SendBufferSizeOption - // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the receive buffer size option. ReceiveBufferSizeOption @@ -1011,12 +858,54 @@ type SettableSocketOption interface { isSettableSocketOption() } +// CongestionControlState indicates the current congestion control state for +// TCP sender. +type CongestionControlState int + +const ( + // Open indicates that the sender is receiving acks in order and + // no loss or dupACK's etc have been detected. + Open CongestionControlState = iota + // RTORecovery indicates that an RTO has occurred and the sender + // has entered an RTO based recovery phase. + RTORecovery + // FastRecovery indicates that the sender has entered FastRecovery + // based on receiving nDupAck's. This state is entered only when + // SACK is not in use. + FastRecovery + // SACKRecovery indicates that the sender has entered SACK based + // recovery. + SACKRecovery + // Disorder indicates the sender either received some SACK blocks + // or dupACK's. + Disorder +) + // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. type TCPInfoOption struct { - RTT time.Duration + // RTT is the smoothed round trip time. + RTT time.Duration + + // RTTVar is the round trip time variation. RTTVar time.Duration + + // RTO is the retransmission timeout for the endpoint. + RTO time.Duration + + // CcState is the congestion control state. + CcState CongestionControlState + + // SndCwnd is the congestion window, in packets. + SndCwnd uint32 + + // SndSsthresh is the threshold between slow start and congestion + // avoidance. + SndSsthresh uint32 + + // ReorderSeen indicates if reordering is seen in the endpoint. + ReorderSeen bool } func (*TCPInfoOption) isGettableSocketOption() {} @@ -1248,6 +1137,31 @@ type IPPacketInfo struct { DestinationAddr Address } +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + // Min is the minimum size for send buffer. + Min int + + // Default is the default size for send buffer. + Default int + + // Max is the maximum size for send buffer. + Max int +} + +// GetSendBufferLimits is used to get the send buffer size limits. +type GetSendBufferLimits func(StackHandler) SendBufferSizeOption + +// GetStackSendBufferLimits is used to get default, min and max send buffer size. +func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption { + var ss SendBufferSizeOption + if err := so.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + return ss +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -1317,8 +1231,33 @@ func (s *StatCounter) String() string { return strconv.FormatUint(s.Value(), 10) } +// A MultiCounterStat keeps track of two counters at once. +type MultiCounterStat struct { + a, b *StatCounter +} + +// Init sets both internal counters to point to a and b. +func (m *MultiCounterStat) Init(a, b *StatCounter) { + m.a = a + m.b = b +} + +// Increment adds one to the counters. +func (m *MultiCounterStat) Increment() { + m.a.Increment() + m.b.Increment() +} + +// IncrementBy increments the counters by v. +func (m *MultiCounterStat) IncrementBy(v uint64) { + m.a.IncrementBy(v) + m.b.IncrementBy(v) +} + // ICMPv4PacketStats enumerates counts for all ICMPv4 packet types. type ICMPv4PacketStats struct { + // LINT.IfChange(ICMPv4PacketStats) + // Echo is the total number of ICMPv4 echo packets counted. Echo *StatCounter @@ -1358,10 +1297,56 @@ type ICMPv4PacketStats struct { // InfoReply is the total number of ICMPv4 information reply packets // counted. InfoReply *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4PacketStats) +} + +// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats. +type ICMPv4SentPacketStats struct { + // LINT.IfChange(ICMPv4SentPacketStats) + + ICMPv4PacketStats + + // Dropped is the total number of ICMPv4 packets dropped due to link + // layer errors. + Dropped *StatCounter + + // RateLimited is the total number of ICMPv4 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4SentPacketStats) +} + +// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. +type ICMPv4ReceivedPacketStats struct { + // LINT.IfChange(ICMPv4ReceivedPacketStats) + + ICMPv4PacketStats + + // Invalid is the total number of invalid ICMPv4 packets received. + Invalid *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4ReceivedPacketStats) +} + +// ICMPv4Stats collects ICMPv4-specific stats. +type ICMPv4Stats struct { + // LINT.IfChange(ICMPv4Stats) + + // PacketsSent contains statistics about sent packets. + PacketsSent ICMPv4SentPacketStats + + // PacketsReceived contains statistics about received packets. + PacketsReceived ICMPv4ReceivedPacketStats + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4Stats) } // ICMPv6PacketStats enumerates counts for all ICMPv6 packet types. type ICMPv6PacketStats struct { + // LINT.IfChange(ICMPv6PacketStats) + // EchoRequest is the total number of ICMPv6 echo request packets // counted. EchoRequest *StatCounter @@ -1416,32 +1401,14 @@ type ICMPv6PacketStats struct { // MulticastListenerDone is the total number of Multicast Listener Done // messages counted. MulticastListenerDone *StatCounter -} - -// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats. -type ICMPv4SentPacketStats struct { - ICMPv4PacketStats - - // Dropped is the total number of ICMPv4 packets dropped due to link - // layer errors. - Dropped *StatCounter - // RateLimited is the total number of ICMPv6 packets dropped due to - // rate limit being exceeded. - RateLimited *StatCounter -} - -// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. -type ICMPv4ReceivedPacketStats struct { - ICMPv4PacketStats - - // Invalid is the total number of ICMPv4 packets received that the - // transport layer could not parse. - Invalid *StatCounter + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6PacketStats) } // ICMPv6SentPacketStats collects outbound ICMPv6-specific stats. type ICMPv6SentPacketStats struct { + // LINT.IfChange(ICMPv6SentPacketStats) + ICMPv6PacketStats // Dropped is the total number of ICMPv6 packets dropped due to link @@ -1451,47 +1418,41 @@ type ICMPv6SentPacketStats struct { // RateLimited is the total number of ICMPv6 packets dropped due to // rate limit being exceeded. RateLimited *StatCounter + + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6SentPacketStats) } // ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats. type ICMPv6ReceivedPacketStats struct { + // LINT.IfChange(ICMPv6ReceivedPacketStats) + ICMPv6PacketStats // Unrecognized is the total number of ICMPv6 packets received that the // transport layer does not know how to parse. Unrecognized *StatCounter - // Invalid is the total number of ICMPv6 packets received that the - // transport layer could not parse. + // Invalid is the total number of invalid ICMPv6 packets received. Invalid *StatCounter // RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets // dropped due to being router-specific packets. RouterOnlyPacketsDroppedByHost *StatCounter -} - -// ICMPv4Stats collects ICMPv4-specific stats. -type ICMPv4Stats struct { - // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type - // and a single count of packets which failed to write to the link - // layer. - PacketsSent ICMPv4SentPacketStats - // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4 - // packet type and a single count of invalid packets received. - PacketsReceived ICMPv4ReceivedPacketStats + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6ReceivedPacketStats) } // ICMPv6Stats collects ICMPv6-specific stats. type ICMPv6Stats struct { - // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type - // and a single count of packets which failed to write to the link - // layer. + // LINT.IfChange(ICMPv6Stats) + + // PacketsSent contains statistics about sent packets. PacketsSent ICMPv6SentPacketStats - // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6 - // packet type and a single count of invalid packets received. + // PacketsReceived contains statistics about received packets. PacketsReceived ICMPv6ReceivedPacketStats + + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6Stats) } // ICMPStats collects ICMP-specific stats (both v4 and v6). @@ -1505,6 +1466,8 @@ type ICMPStats struct { // IGMPPacketStats enumerates counts for all IGMP packet types. type IGMPPacketStats struct { + // LINT.IfChange(IGMPPacketStats) + // MembershipQuery is the total number of Membership Query messages counted. MembershipQuery *StatCounter @@ -1518,22 +1481,29 @@ type IGMPPacketStats struct { // LeaveGroup is the total number of Leave Group messages counted. LeaveGroup *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPPacketStats) } // IGMPSentPacketStats collects outbound IGMP-specific stats. type IGMPSentPacketStats struct { + // LINT.IfChange(IGMPSentPacketStats) + IGMPPacketStats // Dropped is the total number of IGMP packets dropped. Dropped *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPSentPacketStats) } // IGMPReceivedPacketStats collects inbound IGMP-specific stats. type IGMPReceivedPacketStats struct { + // LINT.IfChange(IGMPReceivedPacketStats) + IGMPPacketStats - // Invalid is the total number of IGMP packets received that IGMP could not - // parse. + // Invalid is the total number of invalid IGMP packets received. Invalid *StatCounter // ChecksumErrors is the total number of IGMP packets dropped due to bad @@ -1543,21 +1513,27 @@ type IGMPReceivedPacketStats struct { // Unrecognized is the total number of unrecognized messages counted, these // are silently ignored for forward-compatibilty. Unrecognized *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPReceivedPacketStats) } -// IGMPStats colelcts IGMP-specific stats. +// IGMPStats collects IGMP-specific stats. type IGMPStats struct { - // IGMPSentPacketStats contains counts of sent packets by IGMP packet type - // and a single count of invalid packets received. + // LINT.IfChange(IGMPStats) + + // PacketsSent contains statistics about sent packets. PacketsSent IGMPSentPacketStats - // IGMPReceivedPacketStats contains counts of received packets by IGMP packet - // type and a single count of invalid packets received. + // PacketsReceived contains statistics about received packets. PacketsReceived IGMPReceivedPacketStats + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats) } // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { + // LINT.IfChange(IPStats) + // PacketsReceived is the total number of IP packets received from the // link layer. PacketsReceived *StatCounter @@ -1575,7 +1551,7 @@ type IPStats struct { InvalidSourceAddressesReceived *StatCounter // PacketsDelivered is the total number of incoming IP packets that - // are successfully delivered to the transport layer via HandlePacket. + // are successfully delivered to the transport layer. PacketsDelivered *StatCounter // PacketsSent is the total number of IP packets sent via WritePacket. @@ -1613,10 +1589,14 @@ type IPStats struct { // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived *StatCounter + + // LINT.ThenChange(network/ip/stats.go:MultiCounterIPStats) } // ARPStats collects ARP-specific stats. type ARPStats struct { + // LINT.IfChange(ARPStats) + // PacketsReceived is the number of ARP packets received from the link layer. PacketsReceived *StatCounter @@ -1644,10 +1624,6 @@ type ARPStats struct { // ARP request with a bad local address. OutgoingRequestBadLocalAddressErrors *StatCounter - // OutgoingRequestNetworkUnreachableErrors is the number of failures to send - // an ARP request with a network unreachable error. - OutgoingRequestNetworkUnreachableErrors *StatCounter - // OutgoingRequestsDropped is the number of ARP requests which failed to write // to a link-layer endpoint. OutgoingRequestsDropped *StatCounter @@ -1666,6 +1642,8 @@ type ARPStats struct { // OutgoingRepliesSent is the number of ARP replies successfully written to a // link-layer endpoint. OutgoingRepliesSent *StatCounter + + // LINT.ThenChange(network/arp/stats.go:multiCounterARPStats) } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 1742a178d..71695b630 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -7,6 +7,7 @@ go_test( size = "small", srcs = [ "forward_test.go", + "iptables_test.go", "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", @@ -16,6 +17,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ac9670f9a..38e1881c7 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -38,96 +38,207 @@ import ( var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) -// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { - var e endpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// endpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type endpointWithDestinationCheck struct { - nested.Endpoint -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher. -func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -func TestForwarding(t *testing.T) { - const ( - host1NICID = 1 - routerNICID1 = 2 - routerNICID2 = 3 - host2NICID = 4 - - listenPort = 8080 - ) +const ( + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 +) - host1IPv4Addr := tcpip.ProtocolAddress{ +var ( + host1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 24, }, } - routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + routerNIC1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + routerNIC2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, }, } - host2IPv4Addr := tcpip.ProtocolAddress{ + host2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr := tcpip.ProtocolAddress{ + host1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), PrefixLen: 64, }, } - routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + routerNIC1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + routerNIC2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr := tcpip.ProtocolAddress{ + host2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::2").To16()), PrefixLen: 64, }, } +) + +func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) +} + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + +func TestForwarding(t *testing.T) { + const listenPort = 8080 type endpointAndAddresses struct { serverEP tcpip.Endpoint @@ -229,7 +340,7 @@ func TestForwarding(t *testing.T) { subTests := []struct { name string proto tcpip.TransportProtocolNumber - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) needRemoteAddr bool }{ @@ -250,7 +361,7 @@ func TestForwarding(t *testing.T) { { name: "TCP", proto: tcp.ProtocolNumber, - expectedConnectErr: tcpip.ErrConnectStarted, + expectedConnectErr: &tcpip.ErrConnectStarted{}, setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() @@ -260,7 +371,7 @@ func TestForwarding(t *testing.T) { var addr tcpip.FullAddress for { newEP, wq, err := ep.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-ch continue } @@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) { host1Stack := stack.New(stackOpts) routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) defer epsAndAddrs.serverEP.Close() @@ -415,8 +420,11 @@ func TestForwarding(t *testing.T) { t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) } - if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr { - t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr) + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } } if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) @@ -436,9 +444,10 @@ func TestForwarding(t *testing.T) { write := func(ep tcpip.Endpoint, data []byte) { t.Helper() - dataPayload := tcpip.SlicePayload(data) + var r bytes.Reader + r.Reset(data) var wOpts tcpip.WriteOptions - n, err := ep.Write(dataPayload, wOpts) + n, err := ep.Write(&r, wOpts) if err != nil { t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) } @@ -486,7 +495,7 @@ func TestForwarding(t *testing.T) { read(serverCH, serverEP, data, clientAddr) - data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + data = []byte{5, 6, 7, 8, 9, 10, 11, 12} write(serverEP, data) read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) }) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go new file mode 100644 index 000000000..21a8dd291 --- /dev/null +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -0,0 +1,336 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type inputIfNameMatcher struct { + name string +} + +var _ stack.Matcher = (*inputIfNameMatcher)(nil) + +func (*inputIfNameMatcher) Name() string { + return "inputIfNameMatcher" +} + +func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { + return (hook == stack.Input && im.name != "" && im.name == inNicName), false +} + +const ( + nicID = 1 + nicName = "nic1" + anotherNicName = "nic2" + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + srcAddrV4 = "\x0a\x00\x00\x01" + dstAddrV4 = "\x0a\x00\x00\x02" + srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + payloadSize = 20 +) + +func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + } + return s, e +} + +func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + e := channel.New(0, header.IPv4MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + } + return s, e +} + +func genPacketV6() *stack.PacketBuffer { + pktSize := header.IPv6MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv6(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadSize, + TransportProtocol: 99, + HopLimit: 255, + SrcAddr: srcAddrV6, + DstAddr: dstAddrV6, + }) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func genPacketV4() *stack.PacketBuffer { + pktSize := header.IPv4MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv4Fields{ + TOS: 0, + TotalLength: uint16(pktSize), + ID: 1, + Flags: 0, + FragmentOffset: 16, + TTL: 48, + Protocol: 99, + SrcAddr: srcAddrV4, + DstAddr: dstAddrV4, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func TestIPTablesStatsForInput(t *testing.T) { + tests := []struct { + name string + setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) + setupFilter func(*testing.T, *stack.Stack) + genPacket func() *stack.PacketBuffer + proto tcpip.NetworkProtocolNumber + expectReceived int + expectInputDropped int + }{ + { + name: "IPv6 Accept", + setupStack: genStackV6, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept", + setupStack: genStackV4, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface matches)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface matches)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface does not match but invert is true)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface does not match but invert is true)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match using a matcher)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match using a matcher)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := test.setupStack(t) + test.setupFilter(t, s) + e.InjectInbound(test.proto, test.genPacket()) + + if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { + t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) + } + if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { + t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index af32d3009..7069352f2 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -19,11 +19,14 @@ import ( "fmt" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -32,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -207,8 +211,10 @@ func TestPing(t *testing.T) { defer ep.Close() icmpBuf := test.icmpBuf(t) + var r bytes.Reader + r.Reset(icmpBuf) wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(_, _): %s", err) } else if want := int64(len(icmpBuf)); n != want { t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) @@ -251,7 +257,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name string netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error sockError tcpip.SockError }{ { @@ -270,9 +276,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv4 without resolvable remote", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, + Err: &tcpip.ErrNoRoute{}, ErrType: byte(header.ICMPv4DstUnreachable), ErrCode: byte(header.ICMPv4HostUnreachable), ErrOrigin: tcpip.SockExtErrorOriginICMP, @@ -292,9 +298,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv6 without resolvable remote", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, + Err: &tcpip.ErrNoRoute{}, ErrType: byte(header.ICMPv6DstUnreachable), ErrCode: byte(header.ICMPv6AddressUnreachable), ErrOrigin: tcpip.SockExtErrorOriginICMP6, @@ -351,16 +357,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr := listenerAddr remoteAddr.Addr = test.remoteAddr - if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted) + { + err := clientEP.Connect(remoteAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) + } } // Wait for an error due to link resolution failing, or the endpoint to be // writable. <-ch - var wOpts tcpip.WriteOptions - if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr { - t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + _, err := clientEP.Write(&r, wOpts) + if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { + t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) + } } if test.expectedWriteErr == nil { @@ -374,7 +388,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { sockErrCmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b *tcpip.Error) bool { + cmp.Comparer(func(a, b tcpip.Error) bool { // tcpip.Error holds an unexported field but the errors netstack uses // are pre defined so we can simply compare pointers. return a == b @@ -404,20 +418,134 @@ func TestGetLinkAddress(t *testing.T) { ) tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedLinkAddr bool + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedOk bool }{ { - name: "IPv4", + name: "IPv4 resolvable", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedOk: true, }, { - name: "IPv6", + name: "IPv6 resolvable", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedOk: true, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, useNeighborCache := range []bool{true, false} { + t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + UseNeighborCache: useNeighborCache, + } + + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + wantRes := stack.LinkResolutionResult{Success: test.expectedOk} + if test.expectedOk { + wantRes.LinkAddress = linkAddr2 + } + if diff := cmp.Diff(wantRes, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) + } + }) + } +} + +func TestRouteResolvedFields(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + localAddr tcpip.Address + remoteAddr tcpip.Address + immediatelyResolvable bool + expectedSuccess bool + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4 immediately resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv4AllSystems, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems), + }, + { + name: "IPv6 immediately resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv6AllNodesMulticastAddress, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, + { + name: "IPv4 resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv6 resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, }, } @@ -431,28 +559,618 @@ func TestGetLinkAddress(t *testing.T) { } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + var wantRouteInfo stack.RouteInfo + wantRouteInfo.LocalLinkAddress = linkAddr1 + wantRouteInfo.LocalAddress = test.localAddr + wantRouteInfo.RemoteAddress = test.remoteAddr + wantRouteInfo.NetProto = test.netProto + wantRouteInfo.Loop = stack.PacketOut + wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr + + ch := make(chan stack.ResolvedFieldsResult, 1) + + if !test.immediatelyResolvable { + wantUnresolvedRouteInfo := wantRouteInfo + wantUnresolvedRouteInfo.RemoteLinkAddress = "" - for i := 0; i < 2; i++ { - addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {}) - var want *tcpip.Error - if i == 0 { - want = tcpip.ErrWouldBlock + err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } - if err != want { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want) + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) } - if i == 0 { - <-ch - continue + if !test.expectedSuccess { + return } - if addr != linkAddr2 { - t.Fatalf("got addr = %s, want = %s", addr, linkAddr2) + // At this point the neighbor table should be populated so the route + // should be immediately resolvable. + } + + if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }); err != nil { + t.Errorf("r.ResolvedFields(_): %s", err) + } + select { + case routeResolveRes := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected route to be immediately resolvable") } }) } }) } } + +func TestWritePacketsLinkResolution(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedWriteErr tcpip.Error + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + } + + host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) + } + defer serverEP.Close() + + serverAddr := tcpip.FullAddress{Port: 1234} + if err := serverEP.Bind(serverAddr); err != nil { + t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) + } + + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + data := []byte{1, 2} + var pkts stack.PacketBufferList + for _, d := range data { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: buffer.View([]byte{d}).ToVectorisedView(), + }) + pkt.TransportProtocolNumber = udp.ProtocolNumber + length := uint16(pkt.Size()) + udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: serverAddr.Port, + Length: length, + }) + xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + + pkts.PushBack(pkt) + } + + params := stack.NetworkHeaderParams{ + Protocol: udp.ProtocolNumber, + TTL: 64, + TOS: stack.DefaultTOS, + } + + if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { + t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) + } else if want := pkts.Len(); want != n { + t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) + } + + var writer bytes.Buffer + count := 0 + for { + var rOpts tcpip.ReadOptions + res, err := serverEP.Read(&writer, rOpts) + if err != nil { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + // Should not have anymore bytes to read after we read the sent + // number of bytes. + if count == len(data) { + break + } + + <-serverCH + continue + } + + t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) + } + count += res.Count + } + + if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { + t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if diff := cmp.Diff(data, writer.Bytes()); diff != "" { + t.Errorf("read bytes mismatch (-want +got):\n%s", diff) + } + }) + } +} + +type eventType int + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved +) + +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + entry stack.NeighborEntry +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) +} + +var _ stack.NUDDispatcher = (*nudDispatcher)(nil) + +type nudDispatcher struct { + c chan eventInfo +} + +func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) waitForEvent(want eventInfo) error { + if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil +} + +// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it +// that the neighbor used for a route is reachable. +func TestTCPConfirmNeighborReachability(t *testing.T) { + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + neighborAddr tcpip.Address + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + isHost1Listener bool + }{ + { + name: "IPv4 active connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 active connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 passive connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv4 passive connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + nudDisp := nudDispatcher{ + c: make(chan eventInfo, 3), + } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, + UseNeighborCache: true, + } + host1StackOpts := stackOpts + host1StackOpts.NUDDisp = &nudDisp + + host1Stack := stack.New(host1StackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + // Add a reachable dynamic entry to our neighbor table for the remote. + { + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryAdded, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, + }); err != nil { + t.Fatalf("error waiting for initial NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the remote's neighbor entry to be stale before creating a + // TCP connection from host1 to some remote. + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err) + } + // The maximum reachable time for a neighbor is some maximum random factor + // applied to the base reachable time. + // + // See NUDConfigurations.BaseReachableTime for more information. + maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + + listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) + defer listenerEP.Close() + defer clientEP.Close() + listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + { + err := clientEP.Connect(listenerAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) + } + } + + // Wait for the TCP handshake to complete then make sure the neighbor is + // reachable without entering the probe state as TCP should provide NUD + // with confirmation that the neighbor is reachable (indicated by a + // successful 3-way handshake). + <-clientCH + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the neighbor to be stale again then send data to the remote. + // + // On successful transmission, the neighbor should become reachable + // without probing the neighbor as a TCP ACK would be received which is an + // indication of the neighbor being reachable. + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if test.isHost1Listener { + // If host1 is not the client, host1 does not send any data so TCP + // has no way to know it is making forward progress. Because of this, + // TCP should not mark the route reachable and NUD should go through the + // probe state. + clock.Advance(nudConfigs.DelayFirstProbeTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for probe NUD event: %s", err) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3b13ba04d..ab67762ef 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -37,7 +37,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) type ndpDispatcher struct{} -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { @@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { Port: localPort, }, } - n, err := sep.Write(tcpip.SlicePayload(data), wopts) + var r bytes.Reader + r.Reset(data) + n, err := sep.Write(&r, wopts) if err != nil { t.Fatalf("sep.Write(_, _): %s", err) } @@ -260,8 +262,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -320,11 +322,14 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) } - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + { + err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) + } } } @@ -468,13 +473,17 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { Addr: test.dstAddr, Port: localPort, } - if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + { + err := connectingEndpoint.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } } if !test.expectAccept { - if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + _, _, err := listeningEndpoint.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } return } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index ce7c16bd1..d685fdd36 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -479,8 +479,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) - if n, err := wep.ep.Write(data, writeOpts); err != nil { + data := []byte{byte(i), 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := wep.ep.Write(&r, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) @@ -759,8 +761,11 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { if err := ep.SetSockOpt(&removeOpt); err != nil { t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) } - if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) + { + _, err := ep.Read(&buf, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) + } } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index b222d2b05..9654c9527 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -81,7 +81,7 @@ func TestLocalPing(t *testing.T) { linkEndpoint func() stack.LinkEndpoint localAddr tcpip.Address icmpBuf func(*testing.T) buffer.View - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) }{ { @@ -126,7 +126,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -135,7 +135,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -144,7 +144,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, { @@ -153,7 +153,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, } @@ -186,17 +186,22 @@ func TestLocalPing(t *testing.T) { defer ep.Close() connAddr := tcpip.FullAddress{Addr: test.localAddr} - if err := ep.Connect(connAddr); err != test.expectedConnectErr { - t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + { + err := ep.Connect(connAddr) + if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) + } } if test.expectedConnectErr != nil { return } - payload := tcpip.SlicePayload(test.icmpBuf(t)) + payload := test.icmpBuf(t) + var r bytes.Reader + r.Reset(payload) var wOpts tcpip.WriteOptions - if n, err := ep.Write(payload, wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) } else if n != int64(len(payload)) { t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) @@ -261,12 +266,12 @@ func TestLocalUDP(t *testing.T) { subTests := []struct { name string addAddress bool - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error }{ { name: "Unassigned local address", addAddress: false, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, }, { name: "Assigned local address", @@ -329,12 +334,14 @@ func TestLocalUDP(t *testing.T) { Port: 80, } - clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + clientPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(clientPayload) wOpts := tcpip.WriteOptions{ To: &serverAddr, } - if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) } else if subTest.expectedWriteErr != nil { // Nothing else to test if we expected not to be able to send the @@ -376,12 +383,14 @@ func TestLocalUDP(t *testing.T) { } } - serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + serverPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(serverPayload) wOpts := tcpip.WriteOptions{ To: &clientAddr, } - if n, err := server.Write(serverPayload, wOpts); err != nil { + if n, err := server.Write(&r, wOpts); err != nil { t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) } else if n != int64(len(serverPayload)) { t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 256e19296..3cf05520d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -69,8 +69,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int + mu sync.RWMutex `state:"nosave"` // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState @@ -85,7 +84,7 @@ type endpoint struct { ops tcpip.SocketOptions } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -94,11 +93,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.SetSendBufferSize(32*1024, false /* notify */) + + // Override with stack defaults. + var ss tcpip.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) + } return ep, nil } @@ -119,7 +124,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -154,14 +159,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -188,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -199,7 +204,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.state { case stateInitial: case stateConnected: @@ -207,11 +212,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case stateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -236,18 +241,18 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -257,10 +262,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -270,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. @@ -292,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -313,11 +318,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc route = r } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } + var err tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) @@ -334,12 +340,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.TTLOption: e.mu.Lock() @@ -351,7 +357,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -362,11 +368,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSize - e.mu.Unlock() - return v, nil case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() @@ -381,18 +382,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -410,7 +411,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi // Linux performs these basic checks. if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } icmpv4.SetChecksum(0) @@ -424,9 +425,9 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -441,7 +442,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } dataVV := data.ToVectorisedView() @@ -456,7 +457,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { return tcpip.FullAddress{}, 0, err @@ -465,12 +466,12 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -485,12 +486,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -535,19 +536,19 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() e.shutdownFlags |= flags if e.state != stateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } if flags&tcpip.ShutdownRead != 0 { @@ -565,31 +566,31 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) - switch err { + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + switch err.(type) { case nil: return true, nil - case tcpip.ErrPortInUse: + case *tcpip.ErrPortInUse: return false, nil default: return false, err @@ -599,11 +600,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ return id, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. if e.state != stateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -619,7 +620,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { if len(addr.Addr) != 0 { // A local address was specified, verify that it's valid. if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -647,7 +648,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -663,7 +664,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -675,12 +676,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.state != stateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -805,7 +806,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. -func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 9d263c0ec..c9fa9974a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -69,12 +69,13 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) if e.state != stateBound && e.state != stateConnected { return } - var err *tcpip.Error + var err tcpip.Error if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { @@ -84,7 +85,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.ID.LocalAddress = e.route.LocalAddress } else if len(e.ID.LocalAddress) != 0 { // stateBound if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3820e5dc7..47f7dd1cb 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -59,18 +59,18 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } @@ -87,7 +87,7 @@ func (p *protocol) MinimumPacketSize() int { } // ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. -func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { switch p.number { case ProtocolNumber4: hdr := header.ICMPv4(v) @@ -106,13 +106,13 @@ func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stac } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c0d6fb442..73bb66830 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -79,24 +79,22 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool - boundNIC tcpip.NICID + mu sync.RWMutex `state:"nosave"` + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID // lastErrorMu protects lastError. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // ops is used to get socket level options. ops tcpip.SocketOptions } // NewEndpoint returns a new packet endpoint. -func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -106,14 +104,13 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb netProto: netProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - ep.sndBufSizeMax = ss.Default + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -162,16 +159,16 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. if ep.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if ep.rcvClosed { ep.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } ep.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -201,49 +198,49 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul n, err := packet.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be // disconnected, and this function always returns tpcip.ErrNotSupported. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be -// connected, and this function always returnes tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return tcpip.ErrNotSupported +// connected, and this function always returnes *tcpip.ErrNotSupported. +func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used -// with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - return tcpip.ErrNotSupported +// with Shutdown, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with -// Listen, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +// Listen, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with -// Accept, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +// Accept, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { // TODO(gvisor.dev/issue/173): Add Bind support. // "By default, all packets of the specified protocol type are passed @@ -277,14 +274,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -306,38 +303,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be // used with SetSockOpt, and this function always returns -// tcpip.ErrNotSupported. -func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +// *tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := ep.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - ep.mu.Lock() - ep.sndBufSizeMax = v - ep.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -357,11 +336,11 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (ep *endpoint) LastError() *tcpip.Error { +func (ep *endpoint) LastError() tcpip.Error { ep.lastErrorMu.Lock() defer ep.lastErrorMu.Unlock() @@ -371,19 +350,19 @@ func (ep *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (ep *endpoint) UpdateLastError(err *tcpip.Error) { +func (ep *endpoint) UpdateLastError(err tcpip.Error) { ep.lastErrorMu.Lock() ep.lastError = err ep.lastErrorMu.Unlock() } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrNotSupported{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -395,12 +374,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - ep.mu.Lock() - v := ep.sndBufSizeMax - ep.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: ep.rcvMu.Lock() v := ep.rcvBufSizeMax @@ -408,7 +381,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index e2fa96d17..ece662c0d 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -63,29 +63,11 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. ep.stack = stack.StackFromEnv + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(*err) + panic(err) } } - -// saveLastError is invoked by stateify. -func (ep *endpoint) saveLastError() string { - if ep.lastError == nil { - return "" - } - - return ep.lastError.String() -} - -// loadLastError is invoked by stateify. -func (ep *endpoint) loadLastError(s string) { - if s == "" { - return - } - - ep.lastError = tcpip.StringToError(s) -} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ae743f75e..9c9ccc0ff 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -76,12 +76,10 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - connected bool - bound bool + mu sync.RWMutex `state:"nosave"` + closed bool + connected bool + bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` @@ -95,13 +93,13 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } e := &endpoint{ @@ -112,16 +110,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, associated: associated, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetHeaderIncluded(!associated) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -138,7 +136,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return nil, err } @@ -159,7 +157,7 @@ func (e *endpoint) Close() { return } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.rcvMu.Lock() defer e.rcvMu.Unlock() @@ -191,16 +189,16 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -227,37 +225,37 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } // Write implements tcpip.Endpoint.Write. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // We can create, but not write to, unassociated IPv6 endpoints. if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } } n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -267,22 +265,22 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } e.mu.RLock() defer e.mu.RUnlock() if e.closed { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } - payloadBytes, err := p.FullPayload() - if err != nil { - return 0, err + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} } // If this is an unassociated socket and callee provided a nonzero @@ -290,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } dstAddr := ip.DestinationAddress() // Update dstAddr with the address in the IP header, unless @@ -311,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - return 0, tcpip.ErrDestinationRequired + return 0, &tcpip.ErrDestinationRequired{} } return e.finishWrite(payloadBytes, e.route) @@ -321,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } // Find the route to the destination. If BindAddress is 0, @@ -338,7 +336,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), @@ -365,22 +363,22 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { - return tcpip.ErrAddressFamilyNotSupported + return &tcpip.ErrAddressFamilyNotSupported{} } e.mu.Lock() defer e.mu.Unlock() if e.closed { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nic := addr.NIC @@ -395,7 +393,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } } @@ -407,15 +405,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { route.Release() return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = nic } - // Save the route we've connected via. + if e.route != nil { + // If the endpoint was previously connected then release any previous route. + e.route.Release() + } e.route = route e.connected = true @@ -423,42 +424,42 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() if !e.connected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { + return &tcpip.ErrBadLocalAddress{} } if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = addr.NIC e.BindNICID = addr.NIC } @@ -470,14 +471,14 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -498,37 +499,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -548,17 +531,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -570,12 +553,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax @@ -583,7 +560,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } @@ -703,7 +680,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. -func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 4a7e1c039..263ec5146 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -69,10 +69,11 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) // If the endpoint is connected, re-connect. if e.connected { - var err *tcpip.Error + var err tcpip.Error // TODO(gvisor.dev/issue/4906): Properly restore the route with the right // remote address. We used to pass e.remote.RemoteAddress which was // effectively the empty address but since moving e.route to hold a pointer @@ -88,12 +89,12 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if e.bound { if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } if e.associated { - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index f30aa2a4a..e393b993d 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -25,11 +25,11 @@ import ( type EndpointFactory struct{} // NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. -func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } // NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. -func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7e81203ba..fcdd032c5 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -99,7 +99,6 @@ go_test( "//pkg/rand", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6921de0f1..842c1622b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,7 +199,7 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // On success, a handshake h is returned with h.ep.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -267,7 +267,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.mu.Unlock() ep.Close() - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } l.addPendingEndpoint(ep) @@ -281,14 +281,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q l.removePendingEndpoint(ep) - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } deferAccept = l.listenEP.deferAccept } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { ep.mu.Unlock() ep.Close() @@ -313,7 +313,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // established endpoint is returned with e.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) { h, err := l.startHandshake(s, opts, queue, owner) if err != nil { return nil, err @@ -467,7 +467,7 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error { defer s.decRef() h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) @@ -522,7 +522,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // and needs to handle it. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -692,7 +692,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { n.mu.Unlock() n.Close() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index f711cd4df..4695b66d6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -226,7 +226,7 @@ func (h *handshake) checkAck(s *segment) bool { // synSentState handles a segment received when the TCP 3-way handshake is in // the SYN-SENT state. -func (h *handshake) synSentState(s *segment) *tcpip.Error { +func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { @@ -237,7 +237,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.ep.workerCleanup = true // Although the RFC above calls out ECONNRESET, Linux actually returns // ECONNREFUSED here so we do as well. - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -314,12 +314,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. -func (h *handshake) synRcvdState(s *segment) *tcpip.Error { +func (h *handshake) synRcvdState(s *segment) tcpip.Error { if s.flagIsSet(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -349,7 +349,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) if !h.active { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } h.resetState() @@ -412,7 +412,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } -func (h *handshake) handleSegment(s *segment) *tcpip.Error { +func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) @@ -429,7 +429,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error { // processSegments goes through the segment queue and processes up to // maxSegmentsPerWake (if they're available). -func (h *handshake) processSegments() *tcpip.Error { +func (h *handshake) processSegments() tcpip.Error { for i := 0; i < maxSegmentsPerWake; i++ { s := h.ep.segmentQueue.dequeue() if s == nil { @@ -505,7 +505,7 @@ func (h *handshake) start() { } // complete completes the TCP 3-way handshake initiated by h.start(). -func (h *handshake) complete() *tcpip.Error { +func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper resendWaker := sleep.Waker{} @@ -555,7 +555,7 @@ func (h *handshake) complete() *tcpip.Error { case wakerForNotification: n := h.ep.fetchNotifications() if (n¬ifyClose)|(n¬ifyAbort) != 0 { - return tcpip.ErrAborted + return &tcpip.ErrAborted{} } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { @@ -593,19 +593,19 @@ type backoffTimer struct { t *time.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { - return nil, tcpip.ErrTimeout + return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} bt.t = time.AfterFunc(timeout, f) return bt, nil } -func (bt *backoffTimer) reset() *tcpip.Error { +func (bt *backoffTimer) reset() tcpip.Error { bt.timeout *= 2 if bt.timeout > MaxRTO { - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } bt.t.Reset(bt.timeout) return nil @@ -706,7 +706,7 @@ type tcpFields struct { txHash uint32 } -func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { @@ -716,7 +716,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp return nil } -func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error { tf.txHash = e.txHash if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() @@ -755,7 +755,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { // We need to shallow clone the VectorisedView here as ReadToView will // split the VectorisedView and Trim underlying views as it splits. Not // doing the clone here will cause the underlying views of data itself @@ -803,7 +803,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > math.MaxUint16 { tf.rcvWnd = math.MaxUint16 @@ -875,7 +875,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. -func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] @@ -895,55 +895,60 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn return err } -func (e *endpoint) handleWrite() *tcpip.Error { - // Move packets from send queue to send list. The queue is accessible - // from other goroutines and protected by the send mutex, while the send - // list is only accessible from the handler goroutine, so it needs no - // mutexes. +func (e *endpoint) handleWrite() { e.sndBufMu.Lock() + next := e.drainSendQueueLocked() + e.sndBufMu.Unlock() + + e.sendData(next) +} +// Move packets from send queue to send list. +// +// Precondition: e.sndBufMu must be locked. +func (e *endpoint) drainSendQueueLocked() *segment { first := e.sndQueue.Front() if first != nil { e.snd.writeList.PushBackList(&e.sndQueue) e.sndBufInQueue = 0 } + return first +} - e.sndBufMu.Unlock() - +// Precondition: e.mu must be locked. +func (e *endpoint) sendData(next *segment) { // Initialize the next segment to write if it's currently nil. if e.snd.writeNext == nil { - e.snd.writeNext = first + e.snd.writeNext = next } // Push out any new packets. e.snd.sendData() - - return nil } -func (e *endpoint) handleClose() *tcpip.Error { +func (e *endpoint) handleClose() { if !e.EndpointState().connected() { - return nil + return } // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - - return nil } // resetConnectionLocked puts the endpoint in an error state with the given // error code and sends a RST if and only if the error is not ErrConnectionReset // indicating that the connection is being reset due to receiving a RST. This // method must only be called from the protocol goroutine. -func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { +func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) e.hardError = err - if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + switch err.(type) { + case *tcpip.ErrConnectionReset, *tcpip.ErrTimeout: + default: // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -1053,7 +1058,7 @@ func (e *endpoint) drainClosingSegmentQueue() { } } -func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { +func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states // except SYN-SENT, all reset (RST) segments are @@ -1081,7 +1086,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. case StateCloseWait: e.transitionToStateCloseLocked() - e.hardError = tcpip.ErrAborted + e.hardError = &tcpip.ErrAborted{} e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1094,14 +1099,14 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // handleSegment is invoked from the processor goroutine // rather than the worker goroutine. e.notifyProtocolGoroutine(notifyResetByPeer) - return false, tcpip.ErrConnectionReset + return false, &tcpip.ErrConnectionReset{} } } return true, nil } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { +func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1148,7 +1153,7 @@ func (e *endpoint) probeSegment() { // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { +func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. defer e.probeSegment() @@ -1222,7 +1227,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. -func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { +func (e *endpoint) keepaliveTimerExpired() tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() @@ -1236,13 +1241,13 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with @@ -1286,7 +1291,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1332,6 +1337,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } + // Reaching this point means that we successfully completed the 3-way + // handshake with our peer. + // + // Completing the 3-way handshake is an indication that the route is valid + // and the remote is reachable as the only way we can complete a handshake + // is if our SYN reached the remote and their ACK reached us. + e.route.ConfirmReachable() + drained := e.drainDone != nil if drained { close(e.drainDone) @@ -1344,19 +1357,25 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // wakes up. funcs := []struct { w *sleep.Waker - f func() *tcpip.Error + f func() tcpip.Error }{ { w: &e.sndWaker, - f: e.handleWrite, + f: func() tcpip.Error { + e.handleWrite() + return nil + }, }, { w: &e.sndCloseWaker, - f: e.handleClose, + f: func() tcpip.Error { + e.handleClose() + return nil + }, }, { w: &closeWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { // This means the socket is being closed due // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. @@ -1367,10 +1386,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.snd.resendWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { if !e.snd.retransmitTimerExpired() { e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } return nil }, @@ -1381,7 +1400,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.newSegmentWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { return e.handleSegments(false /* fastPath */) }, }, @@ -1391,7 +1410,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.notificationWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { n := e.fetchNotifications() if n¬ifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -1408,11 +1427,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n¬ifyReset != 0 || n¬ifyAbort != 0 { - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} } if n¬ifyResetByPeer != 0 { - return tcpip.ErrConnectionReset + return &tcpip.ErrConnectionReset{} } if n¬ifyClose != 0 && closeTimer == nil { @@ -1491,7 +1510,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. - cleanupOnError := func(err *tcpip.Error) { + cleanupOnError := func(err tcpip.Error) { e.stack.Stats().TCP.CurrentConnected.Decrement() e.workerCleanup = true if err != nil { diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 1d1b01a6c..2d90246e4 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -15,11 +15,11 @@ package tcp_test import ( + "strings" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -37,7 +37,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrNoRoute { + if _, ok := err.(*tcpip.ErrNoRoute); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } } @@ -49,7 +49,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -156,7 +156,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -391,7 +391,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) { t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) } + var r strings.Reader data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + r.Reset(data) + nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() tcp = header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -523,7 +525,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -547,7 +549,7 @@ func TestV4AcceptOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) @@ -611,7 +613,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -633,7 +635,7 @@ func TestV4ListenCloseOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ea509ac73..6e4e26c39 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -386,12 +386,12 @@ type endpoint struct { // hardError is meaningful only when state is stateError. It stores the // error to be returned when read/write syscalls are called and the // endpoint is in this state. hardError is protected by endpoint mu. - hardError *tcpip.Error `state:".(string)"` + hardError tcpip.Error // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // rcvReadMu synchronizes calls to Read. // @@ -557,7 +557,6 @@ type endpoint struct { // When the send side is closed, the protocol goroutine is notified via // sndCloseWaker, and sndClosed is set to true. sndBufMu sync.Mutex `state:"nosave"` - sndBufSize int sndBufUsed int sndClosed bool sndBufInQueue seqnum.Size @@ -869,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, - sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. @@ -882,13 +880,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) + e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - e.sndBufSize = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -967,7 +966,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() - if e.sndClosed || e.sndBufUsed < e.sndBufSize { + sndBufSize := e.getSendBufferSize() + if e.sndClosed || e.sndBufUsed < sndBufSize { result |= waiter.EventOut } e.sndBufMu.Unlock() @@ -1059,7 +1059,7 @@ func (e *endpoint) Close() { if isResetState { // Close the endpoint without doing full shutdown and // send a RST. - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.closeNoShutdownLocked() // Wake up worker to close the endpoint. @@ -1087,7 +1087,7 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1161,7 +1161,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1293,14 +1293,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) hardErrorLocked() *tcpip.Error { +func (e *endpoint) hardErrorLocked() tcpip.Error { err := e.hardError e.hardError = nil return err } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) lastErrorLocked() *tcpip.Error { +func (e *endpoint) lastErrorLocked() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1309,7 +1309,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error { } // LastError implements tcpip.Endpoint.LastError. -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.LockUser() defer e.UnlockUser() if err := e.hardErrorLocked(); err != nil { @@ -1319,7 +1319,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() e.lastErrorMu.Lock() e.lastError = err @@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvReadMu.Lock() defer e.rcvReadMu.Unlock() @@ -1337,7 +1337,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { - if serr == tcpip.ErrClosedForReceive { + if _, ok := serr.(*tcpip.ErrClosedForReceive); ok { e.stats.ReadErrors.ReadClosed.Increment() } return tcpip.ReadResult{}, serr @@ -1377,7 +1377,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // If something is read, we must report it. Report error when nothing is read. if done == 0 && err != nil { - return tcpip.ReadResult{}, tcpip.ErrBadBuffer + return tcpip.ReadResult{}, &tcpip.ErrBadBuffer{} } return tcpip.ReadResult{ Count: done, @@ -1389,7 +1389,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // inclusive range of segments that can be read. // // Precondition: e.rcvReadMu must be held. -func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { +func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1398,7 +1398,7 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { // on a receive. It can expect to read any data after the handshake // is complete. RFC793, section 3.9, p58. if e.EndpointState() == StateSynSent { - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } // The endpoint can be read if it's connected, or if it's already closed @@ -1414,17 +1414,17 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { if err := e.hardErrorLocked(); err != nil { return nil, nil, err } - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } e.stats.ReadErrors.NotConnected.Increment() - return nil, nil, tcpip.ErrNotConnected + return nil, nil, &tcpip.ErrNotConnected{} } if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } return e.rcvList.Front(), e.rcvList.Back(), nil @@ -1476,106 +1476,117 @@ func (e *endpoint) commitRead(done int) *segment { // moment. If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu -func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { +func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: if err := e.hardErrorLocked(); err != nil { return 0, err } - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case !s.connecting() && !s.connected(): - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case s.connecting(): // As per RFC793, page 56, a send request arriving when in connecting // state, can be queued to be completed after the state becomes // connected. Return an error code for the caller of endpoint Write to // try again, until the connection handshake is complete. - return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } // Check if the connection has already been closed for sends. if e.sndClosed { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } - avail := e.sndBufSize - e.sndBufUsed + sndBufSize := e.getSendBufferSize() + avail := sndBufSize - e.sndBufUsed if avail <= 0 { - return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } return avail, nil } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. e.LockUser() - e.sndBufMu.Lock() - - avail, err := e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, err - } - - // We can release locks while copying data. - // - // This is not possible if atomic is set, because we can't allow the - // available buffer space to be consumed by some other caller while we - // are copying data in. - if !opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - - // Fetch data. - v, perr := p.Payload(avail) - if perr != nil || len(v) == 0 { - // Note that perr may be nil if len(v) == 0. - if opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - return 0, perr - } + defer e.UnlockUser() - if !opts.Atomic { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - e.LockUser() + nextSeg, n, err := func() (*segment, int, tcpip.Error) { e.sndBufMu.Lock() + defer e.sndBufMu.Unlock() + avail, err := e.isEndpointWritableLocked() if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, err + return nil, 0, err } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + v, err := func() ([]byte, tcpip.Error) { + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + defer e.sndBufMu.Lock() + + e.UnlockUser() + defer e.LockUser() + } + + // Fetch data. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return nil, nil + } + v := make([]byte, avail) + if _, err := io.ReadFull(p, v); err != nil { + return nil, &tcpip.ErrBadBuffer{} + } + return v, nil + }() + if len(v) == 0 || err != nil { + return nil, 0, err + } + + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.stats.WriteErrors.WriteClosed.Increment() + return nil, 0, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } - } - // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() + // Add data to the send queue. + s := newOutgoingSegment(e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) - // Do the work inline. - e.handleWrite() - e.UnlockUser() - return int64(len(v)), nil + return e.drainSendQueueLocked(), len(v), nil + }() + if err != nil { + return 0, err + } + e.sendData(nextSeg) + return int64(n), nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1682,8 +1693,16 @@ func (e *endpoint) OnCorkOptionSet(v bool) { } } +func (e *endpoint) getSendBufferSize() int { + sndBufSize, err := e.ops.GetSendBufferSize() + if err != nil { + panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err)) + } + return int(sndBufSize) +} + // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 const inetECNMask = 3 @@ -1711,7 +1730,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.MaxSegOption: userMSS := v if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.userMSS = uint16(userMSS) @@ -1722,7 +1741,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Return not supported if attempting to set this option to // anything other than path MTU discovery disabled. if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.ReceiveBufferSizeOption: @@ -1775,31 +1794,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) - } - - if v > ss.Max { - v = ss.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < ss.Min { - v = ss.Min - } - } else { - v = math.MaxInt32 - } - - e.sndBufMu.Lock() - e.sndBufSize = v - e.sndBufMu.Unlock() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1807,7 +1801,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.TCPSynCountOption: if v < 1 || v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.maxSynRetries = uint8(v) @@ -1823,7 +1817,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: e.UnlockUser() - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -1844,7 +1838,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -1890,7 +1884,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. - return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPLingerTimeoutOption: e.LockUser() @@ -1933,13 +1927,13 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // readyReceiveSize returns the number of bytes ready to be received. -func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { +func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { e.LockUser() defer e.UnlockUser() // The endpoint cannot be in listen state. if e.EndpointState() == StateListen { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } e.rcvListMu.Lock() @@ -1949,7 +1943,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.KeepaliveCountOption: e.keepalive.Lock() @@ -1985,12 +1979,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - v := e.sndBufSize - e.sndBufMu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvListMu.Lock() v := e.rcvBufSize @@ -2019,24 +2007,38 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return 1, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} + } +} + +func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { + info := tcpip.TCPInfoOption{} + e.LockUser() + snd := e.snd + if snd != nil { + // We do not calculate RTT before sending the data packets. If + // the connection did not send and receive data, then RTT will + // be zero. + snd.rtt.Lock() + info.RTT = snd.rtt.srtt + info.RTTVar = snd.rtt.rttvar + snd.rtt.Unlock() + + info.RTO = snd.rto + info.CcState = snd.state + info.SndSsthresh = uint32(snd.sndSsthresh) + info.SndCwnd = uint32(snd.sndCwnd) + info.ReorderSeen = snd.rc.reorderSeen } + e.UnlockUser() + return info } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.TCPInfoOption: - *o = tcpip.TCPInfoOption{} - e.LockUser() - snd := e.snd - e.UnlockUser() - if snd != nil { - snd.rtt.Lock() - o.RTT = snd.rtt.srtt - o.RTTVar = snd.rtt.rttvar - snd.rtt.Unlock() - } + *o = e.getTCPInfo() case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -2082,14 +2084,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -2098,18 +2100,20 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { err := e.connect(addr, true, true) - if err != nil && !err.IgnoreStats() { - // Connect failed. Let's wake up any waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } @@ -2120,7 +2124,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2139,7 +2143,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return nil } // Otherwise return that it's already connected. - return tcpip.ErrAlreadyConnected + return &tcpip.ErrAlreadyConnected{} } nicID := addr.NIC @@ -2152,7 +2156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if nicID != 0 && nicID != e.boundNICID { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } nicID = e.boundNICID @@ -2164,16 +2168,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc case StateConnecting, StateSynSent, StateSynRecv: // A connection request has already been issued but hasn't completed // yet. - return tcpip.ErrAlreadyConnecting + return &tcpip.ErrAlreadyConnecting{} case StateError: if err := e.hardErrorLocked(); err != nil { return err } - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Find a route to the desired destination. @@ -2190,7 +2194,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2229,12 +2233,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { - if err != tcpip.ErrPortInUse || !reuse { + if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } transEPID := e.ID @@ -2278,9 +2282,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) - if err == tcpip.ErrPortInUse { + if _, ok := err.(*tcpip.ErrPortInUse); ok { return false, nil } return false, err @@ -2335,23 +2339,23 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } - return tcpip.ErrConnectStarted + return &tcpip.ErrConnectStarted{} } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection to its // peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() return e.shutdownLocked(flags) } -func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.shutdownFlags |= flags switch { case e.EndpointState().connected(): @@ -2366,7 +2370,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) // Wake up worker to terminate loop. e.notifyProtocolGoroutine(notifyTickleWorker) return nil @@ -2380,7 +2384,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // Already closed. e.sndBufMu.Unlock() if e.EndpointState() == StateTimeWait { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } @@ -2413,22 +2417,24 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } return nil default: - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } } // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) tcpip.Error { err := e.listen(backlog) - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } -func (e *endpoint) listen(backlog int) *tcpip.Error { +func (e *endpoint) listen(backlog int) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2446,7 +2452,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Adjust the size of the channel iff we can fix // existing pending connections into the new one. if len(e.acceptedChan) > backlog { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } if cap(e.acceptedChan) == backlog { return nil @@ -2478,11 +2484,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Endpoint must be bound before it can transition to listen mode. if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } @@ -2518,7 +2524,7 @@ func (e *endpoint) startAcceptedLoop() { // to an endpoint previously set to listen mode. // // addr if not-nil will contain the peer address of the returned endpoint. -func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2527,7 +2533,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. if rcvClosed || e.EndpointState() != StateListen { - return nil, nil, tcpip.ErrInvalidEndpointState + return nil, nil, &tcpip.ErrInvalidEndpointState{} } // Get the new accepted endpoint. @@ -2538,7 +2544,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. case n = <-e.acceptedChan: e.acceptCond.Signal() default: - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } if peerAddr != nil { *peerAddr = n.getRemoteAddress() @@ -2547,19 +2553,19 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. } // Bind binds the endpoint to a specific local port and optionally address. -func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) Bind(addr tcpip.FullAddress) (err tcpip.Error) { e.LockUser() defer e.UnlockUser() return e.bindLocked(addr) } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. if e.EndpointState() != StateInitial { - return tcpip.ErrAlreadyBound + return &tcpip.ErrAlreadyBound{} } e.BindAddr = addr.Addr @@ -2587,7 +2593,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { if len(addr.Addr) != 0 { nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } e.ID.LocalAddress = addr.Addr } @@ -2604,7 +2610,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2628,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2640,12 +2646,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() if !e.EndpointState().connected() { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return e.getRemoteAddress(), nil @@ -2677,7 +2683,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -2726,26 +2732,27 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlAddressUnreachable: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNetworkUnreachable{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } // updateSndBufferUsage is called by the protocol goroutine when room opens up // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { + sendBufferSize := e.getSendBufferSize() e.sndBufMu.Lock() - notify := e.sndBufUsed >= e.sndBufSize>>1 + notify := e.sndBufUsed >= sendBufferSize>>1 e.sndBufUsed -= v - // We only notify when there is half the sndBufSize available after + // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < e.sndBufSize>>1 + notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 e.sndBufMu.Unlock() if notify { @@ -2957,8 +2964,9 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() // Copy endpoint send state. + sndBufSize := e.getSendBufferSize() e.sndBufMu.Lock() - s.SndBufSize = e.sndBufSize + s.SndBufSize = sndBufSize s.SndBufUsed = e.sndBufUsed s.SndClosed = e.sndClosed s.SndBufInQueue = e.sndBufInQueue @@ -3023,12 +3031,16 @@ func (e *endpoint) completeState() stack.TCPEndpointState { rc := &e.snd.rc s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + Reord: rc.reorderSeen, + DSACKSeen: rc.dsackSeen, + ReoWnd: rc.reoWnd, + ReoWndIncr: rc.reoWndIncr, + ReoWndPersist: rc.reoWndPersist, + RTTSeq: rc.rttSeq, } return s } @@ -3103,3 +3115,17 @@ func (e *endpoint) Wait() { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// GetTCPSendBufferLimits is used to get send buffer size limits for TCP. +func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.SendBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ba67176b5..c21dbc682 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,9 +55,11 @@ func (e *endpoint) beforeSave() { case epState.connected() || epState.handshake(): if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) + panic(&tcpip.ErrSaveRejection{ + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + }) } - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.mu.Unlock() e.Close() e.mu.Lock() @@ -179,14 +181,16 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + sendBufferSize := e.getSendBufferSize() + if sendBufferSize < ss.Min || sendBufferSize > ss.Max { + panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max)) } } @@ -228,7 +232,8 @@ func (e *endpoint) Resume(s *stack.Stack) { // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } e.mu.Lock() @@ -265,7 +270,8 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -292,24 +298,6 @@ func (e *endpoint) Resume(s *stack.Stack) { } } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // saveRecentTSTime is invoked by stateify. func (e *endpoint) saveRecentTSTime() unixTime { return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} @@ -320,24 +308,6 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { e.recentTSTime = time.Unix(unix.second, unix.nano) } -// saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { - return "" - } - - return e.hardError.String() -} - -// loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { - if s == "" { - return - } - - e.hardError = tcpip.StringToError(s) -} - // saveMeasureTime is invoked by stateify. func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 596178625..2f9fe7ee0 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -143,12 +143,12 @@ func (r *ForwarderRequest) Complete(sendReset bool) { // CreateEndpoint creates a TCP endpoint for the connection request, performing // the 3-way handshake in the process. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { r.mu.Lock() defer r.mu.Unlock() if r.segment == nil { - return nil, tcpip.ErrInvalidEndpointState + return nil, &tcpip.ErrInvalidEndpointState{} } f := r.forwarder diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 1720370c9..04012cd40 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -161,13 +161,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } @@ -178,7 +178,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given tcp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.TCP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -216,7 +216,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. -func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -261,7 +261,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.Lock() @@ -283,7 +283,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.sendBufferSize = *v @@ -292,7 +292,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.recvBufferSize = *v @@ -310,7 +310,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi } // linux returns ENOENT when an invalid congestion control // is specified. - return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() @@ -340,7 +340,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPTimeWaitReuseOption: if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.timeWaitReuse = *v @@ -381,7 +381,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSynRetriesOption: if *v < 1 || *v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.synRetries = uint8(*v) @@ -389,12 +389,12 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.RLock() @@ -493,7 +493,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 307bacca5..d85cb405a 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -22,12 +22,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) -// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as -// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. -// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is -// 1, PTO is inflated by WCDelAckT time to compensate for a potential long -// delayed ACK timer at the receiver. -const wcDelayedACKTimeout = 200 * time.Millisecond +const ( + // wcDelayedACKTimeout is the recommended maximum delayed ACK timer + // value as defined in the RFC. It stands for worst case delayed ACK + // timer (WCDelAckT). When FlightSize is 1, PTO is inflated by + // WCDelAckT time to compensate for a potential long delayed ACK timer + // at the receiver. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. + wcDelayedACKTimeout = 200 * time.Millisecond + + // tcpRACKRecoveryThreshold is the number of loss recoveries for which + // the reorder window is inflated and after that the reorder window is + // reset to its initial value of minRTT/4. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2. + tcpRACKRecoveryThreshold = 16 +) // RACK is a loss detection algorithm used in TCP to detect packet loss and // reordering using transmission timestamp of the packets instead of packet or @@ -44,6 +53,11 @@ type rackControl struct { // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value + // exitedRecovery indicates if the connection is exiting loss recovery. + // This flag is set if the sender is leaving the recovery after + // receiving an ACK and is reset during updating of reorder window. + exitedRecovery bool + // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value @@ -51,15 +65,30 @@ type rackControl struct { // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration + // reorderSeen indicates if reordering has been detected on this + // connection. + reorderSeen bool + + // reoWnd is the reordering window time used for recording packet + // transmission times. It is used to defer the moment at which RACK + // marks a packet lost. + reoWnd time.Duration + + // reoWndIncr is the multiplier applied to adjust reorder window. + reoWndIncr uint8 + + // reoWndPersist is the number of loss recoveries before resetting + // reorder window. + reoWndPersist int8 + // rtt is the RTT of the most recently delivered packet on the // connection (either cumulatively acknowledged or selectively // acknowledged) that was not marked invalid as a possible spurious // retransmission. rtt time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool + // rttSeq is the SND.NXT when rtt is updated. + rttSeq seqnum.Value // xmitTime is the latest transmission timestamp of rackControl.seg. xmitTime time.Time `state:".(unixTime)"` @@ -75,29 +104,36 @@ type rackControl struct { // tlpHighRxt the value of sender.sndNxt at the time of sending // a TLP retransmission. tlpHighRxt seqnum.Value + + // snd is a reference to the sender. + snd *sender } // init initializes RACK specific fields. -func (rc *rackControl) init() { +func (rc *rackControl) init(snd *sender, iss seqnum.Value) { + rc.fack = iss + rc.reoWndIncr = 1 + rc.snd = snd rc.probeTimer.init(&rc.probeWaker) } // update will update the RACK related fields when an ACK has been received. -// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 -func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) { +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 +func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) + tsOffset := rc.snd.ep.tsOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: - // 1. When Timestamping option is available, if the TSVal is less than the - // transmit time of the most recent retransmitted packet. + // 1. When Timestamping option is available, if the TSVal is less than + // the transmit time of the most recent retransmitted packet. // 2. When RTT calculated for the packet is less than the smoothed RTT // for the connection. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // step 2 if seg.xmitCount > 1 { if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { - if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) { return } } @@ -149,9 +185,8 @@ func (rc *rackControl) detectReorder(seg *segment) { } } -// setDSACKSeen updates rack control if duplicate SACK is seen by the connection. -func (rc *rackControl) setDSACKSeen() { - rc.dsackSeen = true +func (rc *rackControl) setDSACKSeen(dsackSeen bool) { + rc.dsackSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -162,7 +197,7 @@ func (s *sender) shouldSchedulePTO() bool { // The connection supports SACK. s.ep.sackPermitted && // The connection is not in loss recovery. - (s.state != RTORecovery && s.state != SACKRecovery) && + (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. s.ep.scoreboard.Sacked() == 0 } @@ -193,7 +228,7 @@ func (s *sender) schedulePTO() { // probeTimerExpired is the same as TLP_send_probe() as defined in // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2. -func (s *sender) probeTimerExpired() *tcpip.Error { +func (s *sender) probeTimerExpired() tcpip.Error { if !s.rc.probeTimer.checkExpiration() { return nil } @@ -272,3 +307,82 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { } } } + +// updateRACKReorderWindow updates the reorder window. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// * Step 4: Update RACK reordering window +// To handle the prevalent small degree of reordering, RACK.reo_wnd serves as +// an allowance for settling time before marking a packet lost. RACK starts +// initially with a conservative window of min_RTT/4. If no reordering has +// been observed RACK uses reo_wnd of zero during loss recovery, in order to +// retransmit quickly, or when the number of DUPACKs exceeds the classic +// DUPACKthreshold. +func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { + dsackSeen := rc.dsackSeen + snd := rc.snd + + // React to DSACK once per round trip. + // If SND.UNA < RACK.rtt_seq: + // RACK.dsack = false + if snd.sndUna.LessThan(rc.rttSeq) { + dsackSeen = false + } + + // If RACK.dsack: + // RACK.reo_wnd_incr += 1 + // RACK.dsack = false + // RACK.rtt_seq = SND.NXT + // RACK.reo_wnd_persist = 16 + if dsackSeen { + rc.reoWndIncr++ + dsackSeen = false + rc.rttSeq = snd.sndNxt + rc.reoWndPersist = tcpRACKRecoveryThreshold + } else if rc.exitedRecovery { + // Else if exiting loss recovery: + // RACK.reo_wnd_persist -= 1 + // If RACK.reo_wnd_persist <= 0: + // RACK.reo_wnd_incr = 1 + rc.reoWndPersist-- + if rc.reoWndPersist <= 0 { + rc.reoWndIncr = 1 + } + rc.exitedRecovery = false + } + + // Reorder window is zero during loss recovery, or when the number of + // DUPACKs exceeds the classic DUPACKthreshold. + // If RACK.reord is FALSE: + // If in loss recovery: (If in fast or timeout recovery) + // RACK.reo_wnd = 0 + // Return + // Else if RACK.pkts_sacked >= RACK.dupthresh: + // RACK.reo_wnd = 0 + // return + if !rc.reorderSeen { + if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { + rc.reoWnd = 0 + return + } + + if snd.sackedOut >= nDupAckThreshold { + rc.reoWnd = 0 + return + } + } + + // Calculate reorder window. + // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr + // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) + snd.rtt.Lock() + srtt := snd.rtt.srtt + snd.rtt.Unlock() + rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) + if srtt < rc.reoWnd { + rc.reoWnd = srtt + } +} + +func (rc *rackControl) exitRecovery() { + rc.exitedRecovery = true +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 405a6dce7..7a7c402c4 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -347,7 +347,7 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { r.ep.rcvListMu.Lock() rcvClosed := r.ep.rcvClosed || r.closed r.ep.rcvListMu.Unlock() @@ -395,7 +395,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { break @@ -424,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } } @@ -443,7 +443,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { state := r.ep.EndpointState() closed := r.ep.closed diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 079d90848..dfc8fd248 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -48,28 +48,6 @@ const ( MaxRetries = 15 ) -// ccState indicates the current congestion control state for this sender. -type ccState int - -const ( - // Open indicates that the sender is receiving acks in order and - // no loss or dupACK's etc have been detected. - Open ccState = iota - // RTORecovery indicates that an RTO has occurred and the sender - // has entered an RTO based recovery phase. - RTORecovery - // FastRecovery indicates that the sender has entered FastRecovery - // based on receiving nDupAck's. This state is entered only when - // SACK is not in use. - FastRecovery - // SACKRecovery indicates that the sender has entered SACK based - // recovery. - SACKRecovery - // Disorder indicates the sender either received some SACK blocks - // or dupACK's. - Disorder -) - // congestionControl is an interface that must be implemented by any supported // congestion control algorithm. type congestionControl interface { @@ -204,7 +182,7 @@ type sender struct { maxSentAck seqnum.Value // state is the current state of congestion control for this endpoint. - state ccState + state tcpip.CongestionControlState // cc is the congestion control algorithm in use for this sender. cc congestionControl @@ -280,14 +258,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint highRxt: iss, rescueRxt: iss, }, - rc: rackControl{ - fack: iss, - }, gso: ep.gso != nil, } - s.rc.init() - if s.gso { s.ep.gso.MSS = uint16(maxPayloadSize) } @@ -295,6 +268,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.cc = s.initCongestionControl(ep.cc) s.lr = s.initLossRecovery() + s.rc.init(s, iss) // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. @@ -593,7 +567,7 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveRecovery() } - s.state = RTORecovery + s.state = tcpip.RTORecovery s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and @@ -1018,7 +992,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { + if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { if s.sndCwnd > InitialCwnd { s.sndCwnd = InitialCwnd } @@ -1062,14 +1036,14 @@ func (s *sender) enterRecovery() { s.fr.highRxt = s.sndUna s.fr.rescueRxt = s.sndUna if s.ep.sackPermitted { - s.state = SACKRecovery + s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1. s.rc.tlpRxtOut = false return } - s.state = FastRecovery + s.state = tcpip.FastRecovery s.ep.stack.Stats().TCP.FastRecovery.Increment() } @@ -1080,7 +1054,6 @@ func (s *sender) leaveRecovery() { // Deflate cwnd. It had been artificially inflated when new dups arrived. s.sndCwnd = s.sndSsthresh - s.cc.PostRecovery() } @@ -1166,7 +1139,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { s.fr.highRxt = s.sndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() - s.state = Disorder + s.state = tcpip.Disorder return false } @@ -1217,11 +1190,13 @@ func (s *sender) isDupAck(seg *segment) bool { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + s.rc.setDSACKSeen(false) + // Look for DSACK block. idx := 0 n := len(rcvdSeg.parsedOptions.SACKBlocks) if checkDSACK(rcvdSeg) { - s.rc.setDSACKSeen() + s.rc.setDSACKSeen(true) idx = 1 n-- } @@ -1242,7 +1217,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { for _, sb := range sackBlocks { for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true s.sackedOut += s.pCount(seg, s.maxPayloadSize) @@ -1412,6 +1387,17 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { acked := s.sndUna.Size(ack) s.sndUna = ack + // The remote ACK-ing at least 1 byte is an indication that we have a + // full-duplex connection to the remote as the only way we will receive an + // ACK is if the remote received data that we previously sent. + // + // As of writing, linux seems to only confirm a route as reachable when + // forward progress is made which is indicated by an ACK that removes data + // from the retransmit queue. + if acked > 0 { + s.ep.route.ConfirmReachable() + } + ackLeft := acked originalOutstanding := s.outstanding for ackLeft > 0 { @@ -1435,7 +1421,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Update the RACK fields if SACK is enabled. if s.ep.sackPermitted && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1464,7 +1450,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if !s.fr.active { s.cc.Update(originalOutstanding - s.outstanding) if s.fr.last.LessThan(s.sndUna) { - s.state = Open + s.state = tcpip.Open + // Update RACK when we are exiting fast or RTO + // recovery as described in the RFC + // draft-ietf-tcpm-rack-08 Section-7.2 Step 4. + s.rc.exitRecovery() } } @@ -1488,6 +1478,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } + // Update RACK reorder window. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // * Upon receiving an ACK: + // * Step 4: Update RACK reordering window + s.rc.updateRACKReorderWindow(rcvdSeg) + // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. if s.fr.active { @@ -1508,7 +1504,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // sendSegment sends the specified segment. -func (s *sender) sendSegment(seg *segment) *tcpip.Error { +func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() @@ -1539,7 +1535,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. -func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index f7aaee23f..ced3a9c58 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -21,13 +21,13 @@ package tcp_test import ( + "bytes" "fmt" "math" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" @@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in two shots. Packets will only be written at the // MTU size though. - half := data[:len(data)/2] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data[:len(data)/2]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } - half = data[len(data)/2:] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + r.Reset(data[len(data)/2:]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 342eb5eb8..a6a26b705 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -15,11 +15,12 @@ package tcp_test import ( + "bytes" + "fmt" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -61,14 +62,16 @@ func TestRACKUpdate(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(maxPayload) + data := make([]byte, maxPayload) for i := range data { data[i] = byte(i) } // Write the data. xmitTime = time.Now() - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -114,13 +117,15 @@ func TestRACKDetectReorder(t *testing.T) { }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNumToVerify * maxPayload) + data := make([]byte, ackNumToVerify*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -141,17 +146,19 @@ func TestRACKDetectReorder(t *testing.T) { <-probeDone } -func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(numPackets * maxPayload) + data := make([]byte, numPackets*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -528,3 +535,64 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) { // ACK before the test completes. <-probeDone } + +func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) { + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK detects DSACK. + n++ + if n < numACK { + return + } + + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + return + } + + if state.Sender.RACKState.ReoWndIncr != 1 { + probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr) + return + } + + if state.Sender.RACKState.ReoWndPersist > 0 { + probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist) + return + } + probeDone <- nil + }) +} + +func TestRACKCheckReorderWindow(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan error) + const ackNumToVerify = 3 + addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) + + const numPackets = 7 + sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed packets [2-6] which indicates there is reordering + // in the connection. + bytesRead += 6 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + if err := <-probeDone; err != nil { + t.Fatalf("unexpected values for RACK variables: %v", err) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 6635bb815..5024bc925 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -15,6 +15,7 @@ package tcp_test import ( + "bytes" "fmt" "log" "reflect" @@ -22,7 +23,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) { createConnectedWithSACKAndTS(c) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 93683b921..da2730e27 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "math" + "strings" "testing" "time" @@ -26,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -48,7 +48,7 @@ type endpointTester struct { } // CheckReadError issues a read to the endpoint and checking for an error. -func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { +func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { t.Helper() res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) if got != want { @@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for receive to be notified. select { case <-notifyRead: @@ -128,8 +128,11 @@ func TestGiveUpConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventHUp) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Close the connection, wait for completion. @@ -140,8 +143,11 @@ func TestGiveUpConnect(t *testing.T) { // Call Connect again to retreive the handshake failure status // and stats updates. - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAborted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) + } } if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { @@ -194,8 +200,11 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { c.EP = ep want := stats.TCP.FailedConnectionAttempts.Value() + 1 - if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute) + { + err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + } } if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { @@ -211,7 +220,7 @@ func TestCloseWithoutConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -384,7 +393,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -925,8 +934,11 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + { + err := c.EP.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } } // Receive SYN packet with our user supplied MSS. @@ -1347,10 +1359,9 @@ func TestTOSV4(t *testing.T) { testV4Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1396,10 +1407,9 @@ func TestTrafficClassV6(t *testing.T) { testV6Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1444,7 +1454,8 @@ func TestConnectBindToDevice(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("unexpected return value from Connect: %s", err) } @@ -1504,8 +1515,9 @@ func TestSynSent(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { - t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) + err := c.EP.Connect(addr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) } // Receive SYN packet. @@ -1550,9 +1562,9 @@ func TestSynSent(t *testing.T) { ept := endpointTester{c.EP} if test.reset { - ept.CheckReadError(t, tcpip.ErrConnectionRefused) + ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) } else { - ept.CheckReadError(t, tcpip.ErrAborted) + ept.CheckReadError(t, &tcpip.ErrAborted{}) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { @@ -1578,7 +1590,7 @@ func TestOutOfOrderReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1603,7 +1615,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send the first 3 bytes now. c.SendPacket(data[:3], &context.Headers{ @@ -1642,7 +1654,7 @@ func TestOutOfOrderFlood(t *testing.T) { c.CreateConnected(789, 30000, rcvBufSz) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send 100 packets before the actual one that is expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1718,7 +1730,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1786,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1868,13 +1880,13 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { t.Fatalf("Shutdown failed: %s", err) } - ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) @@ -1893,7 +1905,7 @@ func TestFullWindowReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies // the provided buffer value by tcp.SegOverheadFactor to calculate the actual @@ -2054,7 +2066,7 @@ func TestNoWindowShrinking(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. @@ -2176,10 +2188,9 @@ func TestSimpleSend(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2217,10 +2228,9 @@ func TestZeroWindowSend(t *testing.T) { c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2285,10 +2295,9 @@ func TestScaledWindowConnect(t *testing.T) { }) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2317,10 +2326,9 @@ func TestNonScaledWindowConnect(t *testing.T) { c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2376,7 +2384,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2391,10 +2399,9 @@ func TestScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2450,7 +2457,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2465,10 +2472,9 @@ func TestNonScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2632,9 +2638,10 @@ func TestSegmentMerging(t *testing.T) { // Send tcp.InitialCwnd number of segments to fill up // InitialWindow but don't ACK. That should prevent // anymore packets from going out. + var r bytes.Reader for i := 0; i < tcp.InitialCwnd; i++ { - view := buffer.NewViewFromBytes([]byte{0}) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset([]byte{0}) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2644,8 +2651,8 @@ func TestSegmentMerging(t *testing.T) { var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2714,8 +2721,9 @@ func TestDelay(t *testing.T) { var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2761,8 +2769,9 @@ func TestUndelay(t *testing.T) { allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2845,8 +2854,9 @@ func TestMSSNotDelayed(t *testing.T) { allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2894,10 +2904,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { data[i] = byte(i) } - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2963,7 +2972,7 @@ func TestSetTTL(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -2973,8 +2982,11 @@ func TestSetTTL(t *testing.T) { t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("unexpected return value from Connect: %s", err) + } } // Receive SYN packet. @@ -3034,7 +3046,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -3090,7 +3102,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -3115,9 +3127,9 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) { defer c.Cleanup() s := c.Stack() - ch := make(chan *tcpip.Error, 1) + ch := make(chan tcpip.Error, 1) f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = r.CreateEndpoint(&c.WQ) ch <- err }) @@ -3146,7 +3158,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -3165,8 +3177,11 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventOut) defer c.WQ.EventUnregister(&we) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Receive SYN packet. @@ -3276,22 +3291,23 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err { - case tcpip.ErrWouldBlock: + switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { + case *tcpip.ErrWouldBlock: select { case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } - case tcpip.ErrConnectionReset: + case *tcpip.ErrConnectionReset: break loop default: - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -3328,9 +3344,11 @@ func TestSendOnResetConnection(t *testing.T) { time.Sleep(1 * time.Second) // Try to write. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) + var r bytes.Reader + r.Reset(make([]byte, 10)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -3352,7 +3370,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3409,7 +3429,9 @@ func TestMaxRTO(t *testing.T) { c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3458,7 +3480,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) } - if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(make([]byte, tc.size)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } pkt := c.GetPacket() @@ -3595,8 +3619,10 @@ func TestFinWithNoPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3667,9 +3693,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // Write enough segments to fill the congestion window before ACK'ing // any of them. - view := buffer.NewView(10) + view := make([]byte, 10) + var r bytes.Reader for i := tcp.InitialCwnd; i > 0; i-- { - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } } @@ -3754,8 +3782,10 @@ func TestFinWithPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3781,7 +3811,8 @@ func TestFinWithPendingData(t *testing.T) { }) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3841,8 +3872,10 @@ func TestFinWithPartialAck(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3879,7 +3912,8 @@ func TestFinWithPartialAck(t *testing.T) { ) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3985,8 +4019,10 @@ func scaledSendWindow(t *testing.T, scale uint8) { }) // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 65535) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4170,7 +4206,7 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -4249,10 +4285,13 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. - ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) + { + _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) + if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { + t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + } } } @@ -4263,7 +4302,7 @@ func TestReusePort(t *testing.T) { defer c.Cleanup() // First case, just an endpoint that was bound. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed; %s", err) @@ -4293,8 +4332,11 @@ func TestReusePort(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } c.EP.Close() @@ -4351,9 +4393,9 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + s, err := ep.SocketOptions().GetSendBufferSize() if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + t.Fatalf("GetSendBufferSize failed: %s", err) } if int(s) != v { @@ -4459,9 +4501,7 @@ func TestMinMaxBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(149, true) checkSendBufferSize(t, ep, 300) @@ -4473,9 +4513,7 @@ func TestMinMaxBufferSizes(t *testing.T) { // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) @@ -4505,11 +4543,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -4529,7 +4567,7 @@ func TestBindToDeviceOption(t *testing.T) { } } -func makeStack() (*stack.Stack, *tcpip.Error) { +func makeStack() (*stack.Stack, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -4599,8 +4637,11 @@ func TestSelfConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventOut) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } <-notifyCh @@ -4610,9 +4651,9 @@ func TestSelfConnect(t *testing.T) { // Write something. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4752,9 +4793,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("Bind(%d) failed: %s", i, err) } } - want := tcpip.ErrConnectStarted + var want tcpip.Error = &tcpip.ErrConnectStarted{} if collides { - want = tcpip.ErrNoPortAvailable + want = &tcpip.ErrNoPortAvailable{} } if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) @@ -4785,12 +4826,13 @@ func TestPathMTUDiscovery(t *testing.T) { // Send 3200 bytes of data. const writeSize = 3200 - data := buffer.NewView(writeSize) + data := make([]byte, writeSize) for i := range data { data[i] = byte(i) } - - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4878,11 +4920,11 @@ func TestTCPEndpointProbe(t *testing.T) { func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, tc := range testCases { @@ -4964,11 +5006,11 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { func TestEndpointSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, connected := range []bool{false, true} { @@ -4978,7 +5020,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5074,12 +5116,14 @@ func TestKeepalive(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5163,7 +5207,7 @@ func TestKeepalive(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) @@ -5270,7 +5314,7 @@ func TestListenBacklogFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5313,7 +5357,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5330,7 +5374,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5342,7 +5386,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5358,7 +5402,9 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -5583,7 +5629,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5658,7 +5704,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.WQ.EventUnregister(&we) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5674,7 +5720,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() tcp = header.TCP(header.IPv4(pkt).Payload()) if string(tcp.Payload()) != data { @@ -5692,7 +5740,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5733,7 +5781,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5749,7 +5797,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5763,7 +5811,7 @@ func TestSYNRetransmit(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5807,7 +5855,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5882,12 +5930,13 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { }) newEP, _, err := c.EP.Accept(nil) - - if err != nil && err != tcpip.ErrWouldBlock { + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: t.Fatalf("Accept failed: %s", err) } - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Try to accept the connections in the backlog. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -5908,7 +5957,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil { + var r strings.Reader + r.Reset(data) + if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5953,7 +6004,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6023,7 +6074,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6055,7 +6106,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } ept := endpointTester{ep} - ept.CheckReadError(t, tcpip.ErrNotConnected) + ept.CheckReadError(t, &tcpip.ErrNotConnected{}) if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } @@ -6075,7 +6126,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6091,8 +6142,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected) + { + err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { + t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + } } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { @@ -6211,7 +6265,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } } @@ -6335,7 +6389,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } totalCopied += res.Count @@ -6527,7 +6581,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6646,7 +6700,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6753,7 +6807,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6843,7 +6897,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6917,7 +6971,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -7067,7 +7121,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -7103,10 +7157,10 @@ func TestTCPCloseWithData(t *testing.T) { // Now write a few bytes and then close the endpoint. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7204,8 +7258,10 @@ func TestTCPUserTimeout(t *testing.T) { } // Send some data and wait before ACKing it. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7256,7 +7312,7 @@ func TestTCPUserTimeout(t *testing.T) { ) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) @@ -7300,7 +7356,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Now receive 1 keepalives, but don't ACK it. b := c.GetPacket() @@ -7339,7 +7395,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { ), ) - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } @@ -7494,8 +7550,9 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Send data. This should result in an acceptable endpoint. @@ -7552,8 +7609,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7675,13 +7733,13 @@ func TestSetStackTimeWaitReuse(t *testing.T) { s := c.Stack() testCases := []struct { v int - err *tcpip.Error + err tcpip.Error }{ {int(tcpip.TCPTimeWaitReuseDisabled), nil}, {int(tcpip.TCPTimeWaitReuseGlobal), nil}, {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, } for _, tc := range testCases { diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index b65091c3c..5a9745ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -22,7 +22,6 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS // Now send some data and validate that timestamp is echoed correctly in the ACK. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } @@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd // Now send some data with the accepted connection endpoint and validate // that no timestamp option is sent in the TCP segment. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ee55f030c..b1cb9a324 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -586,7 +586,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int // is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 // only endpoint instead of a default dual stack socket. func (c *Context) CreateV6Endpoint(v6only bool) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -689,7 +689,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -749,7 +750,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // Create creates a TCP endpoint. func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -887,7 +888,7 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { // It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK // does not carry an option that was not requested. func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) @@ -903,7 +904,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} err = c.EP.Connect(testFullAddr) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) } // Receive SYN packet. @@ -1054,7 +1055,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9f9b3d510..31a5ddce9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -97,9 +97,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. state EndpointState @@ -111,8 +109,8 @@ type endpoint struct { multicastNICID tcpip.NICID portFlags ports.Flags - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // Values used to reserve a port or register a transport endpoint. // (which ever happens first). @@ -176,18 +174,18 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetMulticastLoop(true) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -217,7 +215,7 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() @@ -227,7 +225,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -246,7 +244,7 @@ func (e *endpoint) Close() { switch e.EndpointState() { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -284,7 +282,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err } @@ -292,10 +290,10 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -342,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -353,7 +351,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.EndpointState() { case StateInitial: case StateConnected: @@ -361,11 +359,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case StateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -391,7 +389,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { localAddr := e.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -417,18 +415,18 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -438,14 +436,14 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if err := e.LastError(); err != nil { return 0, err } // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -461,7 +459,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. @@ -482,9 +480,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -492,7 +493,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if to.Port == 0 { // Port 0 is an invalid port to send to. - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } dst, netProto, err := e.checkV4MappedLocked(*to) @@ -511,19 +512,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, tcpip.ErrBroadcastDisabled + return 0, &tcpip.ErrBroadcastDisabled{} } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. so := e.SocketOptions() if so.GetRecvError() { so.QueueLocalErr( - tcpip.ErrMessageTooLong, + &tcpip.ErrMessageTooLong{}, route.NetProto, header.UDPMaximumPacketSize, tcpip.FullAddress{ @@ -534,7 +535,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc v, ) } - return 0, tcpip.ErrMessageTooLong + return 0, &tcpip.ErrMessageTooLong{} } ttl := e.ttl @@ -584,13 +585,13 @@ func (e *endpoint) OnReusePortSet(v bool) { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: // Return not supported if the value is not disabling path // MTU discovery. if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.MulticastTTLOption: @@ -632,25 +633,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvBufSizeMax = v e.mu.Unlock() return nil - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) - } - - if v < ss.Min { - v = ss.Min - } - if v > ss.Max { - v = ss.Max - } - - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -661,7 +643,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -683,17 +665,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { if nic != 0 { if !e.stack.CheckNIC(nic) { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } else { nic = e.stack.CheckLocalAddress(0, netProto, addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } if e.BindNICID != 0 && e.BindNICID != nic { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.multicastNICID = nic @@ -701,7 +683,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.AddMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -717,7 +699,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -726,7 +708,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToInsert]; ok { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -737,7 +719,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -752,7 +734,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -761,7 +743,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToRemove]; !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -777,7 +759,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.IPv4TOSOption: e.mu.RLock() @@ -811,12 +793,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax @@ -830,12 +806,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -846,14 +822,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { e.mu.Unlock() default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), Data: data, @@ -903,7 +879,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -912,7 +888,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (e *endpoint) Disconnect() *tcpip.Error { +func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -930,12 +906,12 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { - var err *tcpip.Error + var err tcpip.Error id = stack.TransportEndpointID{ LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { return err } @@ -950,7 +926,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.setEndpointState(StateInitial) } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() @@ -961,10 +937,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { if addr.Port == 0 { // We don't support connecting to port zero. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.mu.Lock() @@ -981,12 +957,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1023,7 +999,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { oldPortFlags := e.boundPortFlags - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { r.Release() return err @@ -1031,11 +1007,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id e.boundBindToDevice = btd + if e.route != nil { + // If the endpoint was already connected then make sure we release the + // previous route. + e.route.Release() + } e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID @@ -1051,20 +1032,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. if state := e.EndpointState(); state != StateBound && state != StateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } e.shutdownFlags |= flags @@ -1084,16 +1065,16 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) @@ -1104,7 +1085,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} @@ -1112,11 +1093,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ return id, bindToDevice, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1140,7 +1121,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -1148,7 +1129,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { return err } @@ -1170,7 +1151,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -1186,7 +1167,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -1203,12 +1184,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.EndpointState() != StateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -1341,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -1397,7 +1378,7 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt default: panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) } - e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt) + e.onICMPError(&tcpip.ErrConnectionRefused{}, errType, errCode, extra, pkt) return } } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 13b72dc88..21a6aa460 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,24 +37,6 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). @@ -91,6 +73,7 @@ func (e *endpoint) Resume(s *stack.Stack) { defer e.mu.Unlock() e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { @@ -113,7 +96,7 @@ func (e *endpoint) Resume(s *stack.Stack) { netProto = header.IPv6ProtocolNumber } - var err *tcpip.Error + var err tcpip.Error if state == StateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { @@ -122,7 +105,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } @@ -131,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 49e673d58..705ad1f64 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -69,7 +69,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { } // CreateEndpoint creates a connected UDP endpoint for the session request. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { netHdr := r.pkt.Network() route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) if err != nil { @@ -77,7 +77,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { + if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 91420edd3..427fdd0c9 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -54,13 +54,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } @@ -71,7 +71,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given udp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.UDP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -94,13 +94,13 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4e2123fe9..64c5298d3 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -353,7 +353,7 @@ func (c *testContext) cleanup() { func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { c.t.Helper() - var err *tcpip.Error + var err tcpip.Error c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { c.t.Fatal("NewEndpoint failed: ", err) @@ -555,11 +555,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -599,7 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe var buf bytes.Buffer res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for data to become available. select { case <-ch: @@ -703,8 +703,11 @@ func TestBindReservedPort(t *testing.T) { t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(addr) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } } @@ -716,8 +719,11 @@ func TestBindReservedPort(t *testing.T) { defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint // above, since the endpoint is dual-stack. - if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { @@ -806,11 +812,11 @@ func TestV4ReadSelfSource(t *testing.T) { for _, tt := range []struct { name string handleLocal bool - wantErr *tcpip.Error + wantErr tcpip.Error wantInvalidSource uint64 }{ {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, + {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ @@ -959,15 +965,16 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { // testFailingWrite sends a packet of the given test flow into the UDP endpoint // and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { +func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - payload := buffer.View(newPayload()) - _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + var r bytes.Reader + r.Reset(newPayload()) + _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) c.checkEndpointWriteStats(1, epstats, gotErr) @@ -1007,8 +1014,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, } } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("Write failed: %s", err) } @@ -1089,7 +1098,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { testWrite(c, unicastV6) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) const want = 1 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) @@ -1110,7 +1119,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { testWrite(c, unicastV4in6) // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV4WriteOnV6Only(t *testing.T) { @@ -1120,7 +1129,7 @@ func TestV4WriteOnV6Only(t *testing.T) { c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { @@ -1135,7 +1144,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV6WriteOnConnected(t *testing.T) { @@ -1183,8 +1192,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { writeOpts := tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) } @@ -1192,8 +1203,11 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) } - if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { - c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + { + err := c.ep.LastError() + if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } } }) } @@ -2303,21 +2317,21 @@ func TestShutdownWrite(t *testing.T) { t.Fatalf("Shutdown failed: %s", err) } - testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) + testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) } -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { + switch err.(type) { case nil: want.PacketsSent.IncrementBy(incr) - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: want.WriteErrors.InvalidArgs.IncrementBy(incr) - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: want.WriteErrors.WriteClosed.IncrementBy(incr) - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: want.SendErrors.NoRoute.IncrementBy(incr) default: want.SendErrors.SendToNetworkFailed.IncrementBy(incr) @@ -2327,11 +2341,11 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE } } -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil, tcpip.ErrWouldBlock: - case tcpip.ErrClosedForReceive: + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + case *tcpip.ErrClosedForReceive: want.ReadErrors.ReadClosed.IncrementBy(incr) default: c.t.Errorf("Endpoint error missing stats update err %v", err) @@ -2497,31 +2511,54 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } defer ep.Close() - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + var r bytes.Reader + data := []byte{1, 2, 3, 4} to := tcpip.FullAddress{ Addr: test.remoteAddr, Port: 80, } opts := tcpip.WriteOptions{To: &to} - expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { + if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { + return nil + } + return &tcpip.ErrBroadcastDisabled{} + } if !test.requiresBroadcastOpt { expectedErrWithoutBcastOpt = nil } - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + r.Reset(data) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } ep.SocketOptions().SetBroadcast(true) - if n, err := ep.Write(data, opts); err != nil { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != nil { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) } ep.SocketOptions().SetBroadcast(false) - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + r.Reset(data) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } }) } |