diff options
Diffstat (limited to 'pkg/tcpip')
134 files changed, 2484 insertions, 31582 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD deleted file mode 100644 index 65d4d0cd8..000000000 --- a/pkg/tcpip/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "tcpip", - srcs = [ - "packet_buffer.go", - "packet_buffer_state.go", - "tcpip.go", - "time_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/buffer", - "//pkg/tcpip/iptables", - "//pkg/waiter", - ], -) - -go_test( - name = "tcpip_test", - size = "small", - srcs = ["tcpip_test.go"], - embed = [":tcpip"], -) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD deleted file mode 100644 index 78df5a0b1..000000000 --- a/pkg/tcpip/adapters/gonet/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gonet", - srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) - -go_test( - name = "gonet_test", - size = "small", - srcs = ["gonet_test.go"], - embed = [":gonet"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@org_golang_x_net//nettest:go_default_library", - ], -) diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go deleted file mode 100644 index cd6ce930a..000000000 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ /dev/null @@ -1,722 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gonet provides a Go net package compatible wrapper for a tcpip stack. -package gonet - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - errCanceled = errors.New("operation canceled") - errWouldBlock = errors.New("operation would block") -) - -// timeoutError is how the net package reports timeouts. -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - -// A Listener is a wrapper around a tcpip endpoint that implements -// net.Listener. -type Listener struct { - stack *stack.Stack - ep tcpip.Endpoint - wq *waiter.Queue - cancel chan struct{} -} - -// NewListener creates a new Listener. -func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { - // Create TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - if err := ep.Listen(10); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "listen", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - return &Listener{ - stack: s, - ep: ep, - wq: &wq, - cancel: make(chan struct{}), - }, nil -} - -// Close implements net.Listener.Close. -func (l *Listener) Close() error { - l.ep.Close() - return nil -} - -// Shutdown stops the HTTP server. -func (l *Listener) Shutdown() { - l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - close(l.cancel) // broadcast cancellation -} - -// Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - a, err := l.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -type deadlineTimer struct { - // mu protects the fields below. - mu sync.Mutex - - readTimer *time.Timer - readCancelCh chan struct{} - writeTimer *time.Timer - writeCancelCh chan struct{} -} - -func (d *deadlineTimer) init() { - d.readCancelCh = make(chan struct{}) - d.writeCancelCh = make(chan struct{}) -} - -func (d *deadlineTimer) readCancel() <-chan struct{} { - d.mu.Lock() - c := d.readCancelCh - d.mu.Unlock() - return c -} -func (d *deadlineTimer) writeCancel() <-chan struct{} { - d.mu.Lock() - c := d.writeCancelCh - d.mu.Unlock() - return c -} - -// setDeadline contains the shared logic for setting a deadline. -// -// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and -// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and -// deadlineTimer.writeTimer. -// -// setDeadline must only be called while holding d.mu. -func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { - if *timer != nil && !(*timer).Stop() { - *cancelCh = make(chan struct{}) - } - - // Create a new channel if we already closed it due to setting an already - // expired time. We won't race with the timer because we already handled - // that above. - select { - case <-*cancelCh: - *cancelCh = make(chan struct{}) - default: - } - - // "A zero value for t means I/O operations will not time out." - // - net.Conn.SetDeadline - if t.IsZero() { - return - } - - timeout := t.Sub(time.Now()) - if timeout <= 0 { - close(*cancelCh) - return - } - - // Timer.Stop returns whether or not the AfterFunc has started, but - // does not indicate whether or not it has completed. Make a copy of - // the cancel channel to prevent this code from racing with the next - // call of setDeadline replacing *cancelCh. - ch := *cancelCh - *timer = time.AfterFunc(timeout, func() { - close(ch) - }) -} - -// SetReadDeadline implements net.Conn.SetReadDeadline and -// net.PacketConn.SetReadDeadline. -func (d *deadlineTimer) SetReadDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.readCancelCh, &d.readTimer, t) - d.mu.Unlock() - return nil -} - -// SetWriteDeadline implements net.Conn.SetWriteDeadline and -// net.PacketConn.SetWriteDeadline. -func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) - d.mu.Unlock() - return nil -} - -// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. -func (d *deadlineTimer) SetDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.readCancelCh, &d.readTimer, t) - d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) - d.mu.Unlock() - return nil -} - -// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn -// interface. -type Conn struct { - deadlineTimer - - wq *waiter.Queue - ep tcpip.Endpoint - - // readMu serializes reads and implicitly protects read. - // - // Lock ordering: - // If both readMu and deadlineTimer.mu are to be used in a single - // request, readMu must be acquired before deadlineTimer.mu. - readMu sync.Mutex - - // read contains bytes that have been read from the endpoint, - // but haven't yet been returned. - read buffer.View -} - -// NewConn creates a new Conn. -func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { - c := &Conn{ - wq: wq, - ep: ep, - } - c.deadlineTimer.init() - return c -} - -// Accept implements net.Conn.Accept. -func (l *Listener) Accept() (net.Conn, error) { - n, wq, err := l.ep.Accept() - - if err == tcpip.ErrWouldBlock { - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - l.wq.EventRegister(&waitEntry, waiter.EventIn) - defer l.wq.EventUnregister(&waitEntry) - - for { - n, wq, err = l.ep.Accept() - - if err != tcpip.ErrWouldBlock { - break - } - - select { - case <-l.cancel: - return nil, errCanceled - case <-notifyCh: - } - } - } - - if err != nil { - return nil, &net.OpError{ - Op: "accept", - Net: "tcp", - Addr: l.Addr(), - Err: errors.New(err.String()), - } - } - - return NewConn(wq, n), nil -} - -type opErrorer interface { - newOpError(op string, err error) *net.OpError -} - -// commonRead implements the common logic between net.Conn.Read and -// net.PacketConn.ReadFrom. -func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { - select { - case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) - default: - } - - read, _, err := ep.Read(addr) - - if err == tcpip.ErrWouldBlock { - if dontWait { - return nil, errWouldBlock - } - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - for { - read, _, err = ep.Read(addr) - if err != tcpip.ErrWouldBlock { - break - } - select { - case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) - case <-notifyCh: - } - } - } - - if err == tcpip.ErrClosedForReceive { - return nil, io.EOF - } - - if err != nil { - return nil, errorer.newOpError("read", errors.New(err.String())) - } - - return read, nil -} - -// Read implements net.Conn.Read. -func (c *Conn) Read(b []byte) (int, error) { - c.readMu.Lock() - defer c.readMu.Unlock() - - deadline := c.readCancel() - - numRead := 0 - for numRead != len(b) { - if len(c.read) == 0 { - var err error - c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) - if err != nil { - if numRead != 0 { - return numRead, nil - } - return numRead, err - } - } - n := copy(b[numRead:], c.read) - c.read.TrimFront(n) - numRead += n - if len(c.read) == 0 { - c.read = nil - } - } - return numRead, nil -} - -// Write implements net.Conn.Write. -func (c *Conn) Write(b []byte) (int, error) { - deadline := c.writeCancel() - - // Check if deadlineTimer has already expired. - select { - case <-deadline: - return 0, c.newOpError("write", &timeoutError{}) - default: - } - - v := buffer.NewViewFromBytes(b) - - // We must handle two soft failure conditions simultaneously: - // 1. Write may write nothing and return tcpip.ErrWouldBlock. - // If this happens, we need to register for notifications if we have - // not already and wait to try again. - // 2. Write may write fewer than the full number of bytes and return - // without error. In this case we need to try writing the remaining - // bytes again. I do not need to register for notifications. - // - // What is more, these two soft failure conditions can be interspersed. - // There is no guarantee that all of the condition #1s will occur before - // all of the condition #2s or visa-versa. - var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} - ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) - } else { - // Don't wait immediately after registration in case more data - // became available between when we last checked and when we setup - // the notification. - select { - case <-deadline: - return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: - } - } - } - - var n int64 - var resCh <-chan struct{} - n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - - if resCh != nil { - select { - case <-deadline: - return nbytes, c.newOpError("write", &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - } - - if err == nil { - return nbytes, nil - } - - return nbytes, c.newOpError("write", errors.New(err.String())) -} - -// Close implements net.Conn.Close. -func (c *Conn) Close() error { - c.ep.Close() - return nil -} - -// CloseRead shuts down the reading side of the TCP connection. Most callers -// should just use Close. -// -// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. -func (c *Conn) CloseRead() error { - if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { - return c.newOpError("close", errors.New(terr.String())) - } - return nil -} - -// CloseWrite shuts down the writing side of the TCP connection. Most callers -// should just use Close. -// -// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. -func (c *Conn) CloseWrite() error { - if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { - return c.newOpError("close", errors.New(terr.String())) - } - return nil -} - -// LocalAddr implements net.Conn.LocalAddr. -func (c *Conn) LocalAddr() net.Addr { - a, err := c.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -// RemoteAddr implements net.Conn.RemoteAddr. -func (c *Conn) RemoteAddr() net.Addr { - a, err := c.ep.GetRemoteAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -func (c *Conn) newOpError(op string, err error) *net.OpError { - return &net.OpError{ - Op: op, - Net: "tcp", - Source: c.LocalAddr(), - Addr: c.RemoteAddr(), - Err: err, - } -} - -func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { - return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} -} - -func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { - return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} -} - -// DialTCP creates a new TCP Conn connected to the specified address. -func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { - return DialContextTCP(context.Background(), s, addr, network) -} - -// DialContextTCP creates a new TCP Conn connected to the specified address -// with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { - // Create TCP endpoint, then connect. - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - // Create wait queue entry that notifies a channel. - // - // We do this unconditionally as Connect will always return an error. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { - select { - case <-ctx.Done(): - ep.Close() - return nil, ctx.Err() - case <-notifyCh: - } - - err = ep.GetSockOpt(tcpip.ErrorOption{}) - } - if err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "connect", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - return NewConn(&wq, ep), nil -} - -// A PacketConn is a wrapper around a tcpip endpoint that implements -// net.PacketConn. -type PacketConn struct { - deadlineTimer - - stack *stack.Stack - ep tcpip.Endpoint - wq *waiter.Queue -} - -// DialUDP creates a new PacketConn. -// -// If laddr is nil, a local address is automatically chosen. -// -// If raddr is nil, the PacketConn is left unconnected. -func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - if laddr != nil { - if err := ep.Bind(*laddr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "udp", - Addr: fullToUDPAddr(*laddr), - Err: errors.New(err.String()), - } - } - } - - c := PacketConn{ - stack: s, - ep: ep, - wq: &wq, - } - c.deadlineTimer.init() - - if raddr != nil { - if err := c.ep.Connect(*raddr); err != nil { - c.ep.Close() - return nil, &net.OpError{ - Op: "connect", - Net: "udp", - Addr: fullToUDPAddr(*raddr), - Err: errors.New(err.String()), - } - } - } - - return &c, nil -} - -func (c *PacketConn) newOpError(op string, err error) *net.OpError { - return c.newRemoteOpError(op, nil, err) -} - -func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { - return &net.OpError{ - Op: op, - Net: "udp", - Source: c.LocalAddr(), - Addr: remote, - Err: err, - } -} - -// RemoteAddr implements net.Conn.RemoteAddr. -func (c *PacketConn) RemoteAddr() net.Addr { - a, err := c.ep.GetRemoteAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -// Read implements net.Conn.Read -func (c *PacketConn) Read(b []byte) (int, error) { - bytesRead, _, err := c.ReadFrom(b) - return bytesRead, err -} - -// ReadFrom implements net.PacketConn.ReadFrom. -func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - deadline := c.readCancel() - - var addr tcpip.FullAddress - read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) - if err != nil { - return 0, nil, err - } - - return copy(b, read), fullToUDPAddr(addr), nil -} - -func (c *PacketConn) Write(b []byte) (int, error) { - return c.WriteTo(b, nil) -} - -// WriteTo implements net.PacketConn.WriteTo. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - deadline := c.writeCancel() - - // Check if deadline has already expired. - select { - case <-deadline: - return 0, c.newRemoteOpError("write", addr, &timeoutError{}) - default: - } - - // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} - if addr != nil { - ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} - } - - v := buffer.NewView(len(b)) - copy(v, b) - - n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts) - if resCh != nil { - select { - case <-deadline: - return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - } - - if err == tcpip.ErrWouldBlock { - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) - for { - select { - case <-deadline: - return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) - case <-notifyCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - if err != tcpip.ErrWouldBlock { - break - } - } - } - - if err == nil { - return int(n), nil - } - - return int(n), c.newRemoteOpError("write", addr, errors.New(err.String())) -} - -// Close implements net.PacketConn.Close. -func (c *PacketConn) Close() error { - c.ep.Close() - return nil -} - -// LocalAddr implements net.PacketConn.LocalAddr. -func (c *PacketConn) LocalAddr() net.Addr { - a, err := c.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToUDPAddr(a) -} diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go deleted file mode 100644 index ee077ae83..000000000 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ /dev/null @@ -1,683 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gonet - -import ( - "context" - "fmt" - "io" - "net" - "reflect" - "strings" - "testing" - "time" - - "golang.org/x/net/nettest" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - NICID = 1 -) - -func TestTimeouts(t *testing.T) { - nc := NewConn(nil, nil) - dlfs := []struct { - name string - f func(time.Time) error - }{ - {"SetDeadline", nc.SetDeadline}, - {"SetReadDeadline", nc.SetReadDeadline}, - {"SetWriteDeadline", nc.SetWriteDeadline}, - } - - for _, dlf := range dlfs { - if err := dlf.f(time.Time{}); err != nil { - t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil) - } - } -} - -func newLoopbackStack() (*stack.Stack, *tcpip.Error) { - // Create the stack and add a NIC. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, - }) - - if err := s.CreateNIC(NICID, loopback.New()); err != nil { - return nil, err - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - // IPv4 - { - Destination: header.IPv4EmptySubnet, - NIC: NICID, - }, - - // IPv6 - { - Destination: header.IPv6EmptySubnet, - NIC: NICID, - }, - }) - - return s, nil -} - -type testConnection struct { - wq *waiter.Queue - e *waiter.Entry - ch chan struct{} - ep tcpip.Endpoint -} - -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { - wq := &waiter.Queue{} - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - - entry, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&entry, waiter.EventOut) - - err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { - <-ch - err = ep.GetSockOpt(tcpip.ErrorOption{}) - } - if err != nil { - return nil, err - } - - wq.EventUnregister(&entry) - wq.EventRegister(&entry, waiter.EventIn) - - return &testConnection{wq, &entry, ch, ep}, nil -} - -func (c *testConnection) close() { - c.wq.EventUnregister(c.e) - c.ep.Close() -} - -// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock. -func TestCloseReader(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := NewListener(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Fatalf("l.Accept() = %v", err) - } - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - if n != 0 || err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when -// using tcp.Forwarder. -func TestCloseReaderWithForwarder(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - done := make(chan struct{}) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) - } - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestCloseRead(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewConn(tc.wq, tc.ep) - - if err := c.CloseRead(); err != nil { - t.Errorf("c.CloseRead() = %v", err) - } - - buf := make([]byte, 256) - if n, err := c.Read(buf); err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err) - } - - if n, err := c.Write([]byte("abc123")); n != 6 || err != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err) - } -} - -func TestCloseWrite(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - n, e := c.Read(make([]byte, 256)) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e) - } - - if n, e = c.Write([]byte("abc123")); n != 6 || e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewConn(tc.wq, tc.ep) - - if err := c.CloseWrite(); err != nil { - t.Errorf("c.CloseWrite() = %v", err) - } - - buf := make([]byte, 256) - n, err := c.Read(buf) - if err != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err) - } - - n, err = c.Write([]byte("abc123")) - got, ok := err.(*net.OpError) - want := "endpoint is closed for send" - if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) { - t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want) - } -} - -func TestUDPForwarder(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - done := make(chan struct{}) - fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - - c := NewConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil { - t.Errorf("c.Read() = %v", e) - } - - if _, e := c.Write(buf[:n]); e != nil { - t.Errorf("c.Write() = %v", e) - } - }) - s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - sent := "abc123" - sendAddr := fullToUDPAddr(addr1) - if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - - buf := make([]byte, 256) - n, recvAddr, err := c2.ReadFrom(buf) - if err != nil || recvAddr.String() != sendAddr.String() { - t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil) - } -} - -// TestDeadlineChange tests that changing the deadline affects currently blocked reads. -func TestDeadlineChange(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := NewListener(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Fatalf("l.Accept() = %v", err) - } - - c.SetDeadline(time.Now().Add(time.Minute)) - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.SetDeadline(time.Now().Add(time.Millisecond * 10)) - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := "i/o timeout" - if n != 0 || !ok || got.Err == nil || got.Err.Error() != want { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(time.Millisecond * 500): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - sendAddr := fullToUDPAddr(addr2) - if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, recvAddr, err := c2.ReadFrom(recv) - if err != nil || n != len(recv) { - t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) { - t.Errorf("got recvAddr = %v, want = %v", recvAddr, want) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func TestConnectedPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, err := c1.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func makePipe() (c1, c2 net.Conn, stop func(), err error) { - s, e := newLoopbackStack() - if e != nil { - return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - l, err := NewListener(s, addr, ipv4.ProtocolNumber) - if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) - } - - c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) - } - - c2, err = l.Accept() - if err != nil { - l.Close() - c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) - } - - stop = func() { - c1.Close() - c2.Close() - } - - if err := l.Close(); err != nil { - stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) - } - - return c1, c2, stop, nil -} - -func TestTCPConnTransfer(t *testing.T) { - c1, c2, _, err := makePipe() - if err != nil { - t.Fatal(err) - } - defer func() { - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } - }() - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - const sent = "abc123" - - tests := []struct { - name string - c1 net.Conn - c2 net.Conn - }{ - {"connected to accepted", c1, c2}, - {"accepted to connected", c2, c1}, - } - - for _, test := range tests { - if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil) - continue - } - - recv := make([]byte, len(sent)) - n, err := test.c2.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil) - continue - } - - if recv := string(recv); recv != sent { - t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent) - } - } -} - -func TestTCPDialError(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - - _, err := DialTCP(s, addr, ipv4.ProtocolNumber) - got, ok := err.(*net.OpError) - want := tcpip.ErrNoRoute - if !ok || got.Err.Error() != want.String() { - t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) - } -} - -func TestDialContextTCPCanceled(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled) - } -} - -func TestDialContextTCPTimeout(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - time.Sleep(time.Second) - r.Complete(true) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - ctx := context.Background() - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) - defer cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded) - } -} - -func TestNetTest(t *testing.T) { - nettest.TestConn(t, makePipe) -} diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD deleted file mode 100644 index d6c31bfa2..000000000 --- a/pkg/tcpip/buffer/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "buffer", - srcs = [ - "prependable.go", - "view.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer", - visibility = ["//visibility:public"], -) - -go_test( - name = "buffer_test", - size = "small", - srcs = ["view_test.go"], - embed = [":buffer"], -) diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go new file mode 100755 index 000000000..41a74ea97 --- /dev/null +++ b/pkg/tcpip/buffer/buffer_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package buffer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *VectorisedView) beforeSave() {} +func (x *VectorisedView) save(m state.Map) { + x.beforeSave() + m.Save("views", &x.views) + m.Save("size", &x.size) +} + +func (x *VectorisedView) afterLoad() {} +func (x *VectorisedView) load(m state.Map) { + m.Load("views", &x.views) + m.Load("size", &x.size) +} + +func init() { + state.Register("buffer.VectorisedView", (*VectorisedView)(nil), state.Fns{Save: (*VectorisedView).save, Load: (*VectorisedView).load}) +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go deleted file mode 100644 index ebc3a17b7..000000000 --- a/pkg/tcpip/buffer/view_test.go +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package buffer_test contains tests for the VectorisedView type. -package buffer - -import ( - "reflect" - "testing" -) - -// copy returns a deep-copy of the vectorised view. -func (vv VectorisedView) copy() VectorisedView { - uu := VectorisedView{ - views: make([]View, 0, len(vv.views)), - size: vv.size, - } - for _, v := range vv.views { - uu.views = append(uu.views, append(View(nil), v...)) - } - return uu -} - -// vv is an helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) VectorisedView { - views := make([]View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return NewVectorisedView(size, views) -} - -var capLengthTestCases = []struct { - comment string - in VectorisedView - length int - want VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Case spanning across two Views", - in: vv(4, "123", "4"), - length: 2, - want: vv(2, "12"), - }, - { - comment: "Corner case with negative length", - in: vv(1, "1"), - length: -1, - want: vv(0), - }, - { - comment: "Corner case with length = 0", - in: vv(3, "12", "3"), - length: 0, - want: vv(0), - }, - { - comment: "Corner case with length = size", - in: vv(1, "1"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Corner case with length > size", - in: vv(1, "1"), - length: 2, - want: vv(1, "1"), - }, -} - -func TestCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - orig := c.in.copy() - c.in.CapLength(c.length) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v", - c.comment, c.length, orig, c.in, c.want) - } - } -} - -var trimFrontTestCases = []struct { - comment string - in VectorisedView - count int - want VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case where we trim an entire View", - in: vv(2, "1", "2"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: vv(1, "3"), - }, - { - comment: "Corner case with negative count", - in: vv(1, "1"), - count: -1, - want: vv(1, "1"), - }, - { - comment: " Corner case with count = 0", - in: vv(1, "1"), - count: 0, - want: vv(1, "1"), - }, - { - comment: "Corner case with count = size", - in: vv(1, "1"), - count: 1, - want: vv(0), - }, - { - comment: "Corner case with count > size", - in: vv(1, "1"), - count: 2, - want: vv(0), - }, -} - -func TestTrimFront(t *testing.T) { - for _, c := range trimFrontTestCases { - orig := c.in.copy() - c.in.TrimFront(c.count) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v", - c.comment, c.count, orig, c.in, c.want) - } - } -} - -var toViewCases = []struct { - comment string - in VectorisedView - want View -}{ - { - comment: "Simple case", - in: vv(2, "12"), - want: []byte("12"), - }, - { - comment: "Case with multiple views", - in: vv(2, "1", "2"), - want: []byte("12"), - }, - { - comment: "Empty case", - in: vv(0), - want: []byte(""), - }, -} - -func TestToView(t *testing.T) { - for _, c := range toViewCases { - got := c.in.ToView() - if !reflect.DeepEqual(got, c.want) { - t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v", - c.comment, c.in, got, c.want) - } - } -} - -var toCloneCases = []struct { - comment string - inView VectorisedView - inBuffer []View -}{ - { - comment: "Simple case", - inView: vv(1, "1"), - inBuffer: make([]View, 1), - }, - { - comment: "Case with multiple views", - inView: vv(2, "1", "2"), - inBuffer: make([]View, 2), - }, - { - comment: "Case with buffer too small", - inView: vv(2, "1", "2"), - inBuffer: make([]View, 1), - }, - { - comment: "Case with buffer larger than needed", - inView: vv(1, "1"), - inBuffer: make([]View, 2), - }, - { - comment: "Case with nil buffer", - inView: vv(1, "1"), - inBuffer: nil, - }, -} - -func TestToClone(t *testing.T) { - for _, c := range toCloneCases { - t.Run(c.comment, func(t *testing.T) { - got := c.inView.Clone(c.inBuffer) - if !reflect.DeepEqual(got, c.inView) { - t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v", - c.inView, c.inBuffer, got, c.inView) - } - }) - } -} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD deleted file mode 100644 index b6fa6fc37..000000000 --- a/pkg/tcpip/checker/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "checker", - testonly = 1, - srcs = ["checker.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/checker", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go deleted file mode 100644 index 2f15bf1f1..000000000 --- a/pkg/tcpip/checker/checker.go +++ /dev/null @@ -1,756 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package checker provides helper functions to check networking packets for -// validity. -package checker - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// NetworkChecker is a function to check a property of a network packet. -type NetworkChecker func(*testing.T, []header.Network) - -// TransportChecker is a function to check a property of a transport packet. -type TransportChecker func(*testing.T, header.Transport) - -// IPv4 checks the validity and properties of the given IPv4 packet. It is -// expected to be used in conjunction with other network checkers for specific -// properties. For example, to check the source and destination address, one -// would call: -// -// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) -func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv4 := header.IPv4(b) - - if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") - } - - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) - } - - for _, f := range checkers { - f(t, []header.Network{ipv4}) - } - if t.Failed() { - t.FailNow() - } -} - -// IPv6 checks the validity and properties of the given IPv6 packet. The usage -// is similar to IPv4. -func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") - } - - for _, f := range checkers { - f(t, []header.Network{ipv6}) - } - if t.Failed() { - t.FailNow() - } -} - -// SrcAddr creates a checker that checks the source address. -func SrcAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].SourceAddress(); a != addr { - t.Errorf("Bad source address, got %v, want %v", a, addr) - } - } -} - -// DstAddr creates a checker that checks the destination address. -func DstAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].DestinationAddress(); a != addr { - t.Errorf("Bad destination address, got %v, want %v", a, addr) - } - } -} - -// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). -func TTL(ttl uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - var v uint8 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TTL() - case header.IPv6: - v = ip.HopLimit() - } - if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) - } - } -} - -// PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) - } - } -} - -// FragmentOffset creates a checker that checks the FragmentOffset field. -func FragmentOffset(offset uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this of IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) - } - } - } -} - -// FragmentFlags creates a checker that checks the fragment flags field. -func FragmentFlags(flags uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this of IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) - } - } - } -} - -// TOS creates a checker that checks the TOS field. -func TOS(tos uint8, label uint32) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) - } - } -} - -// Raw creates a checker that checks the bytes of payload. -// The checker always checks the payload of the last network header. -// For instance, in case of IPv6 fragments, the payload that will be checked -// is the one containing the actual data that the packet is carrying, without -// the bytes added by the IPv6 fragmentation. -func Raw(want []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// IPv6Fragment creates a checker that validates an IPv6 fragment. -func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) - } - - ipv6Frag := header.IPv6Fragment(h[0].Payload()) - if !ipv6Frag.IsValid() { - t.Error("Not a valid IPv6 fragment") - } - - for _, f := range checkers { - f(t, []header.Network{h[0], ipv6Frag}) - } - if t.Failed() { - t.FailNow() - } - } -} - -// TCP creates a checker that checks that the transport protocol is TCP and -// potentially additional transport header fields. -func TCP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - first := h[0] - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) - } - - // Verify the checksum. - tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) - } - - // Run the transport checkers. - for _, f := range checkers { - f(t, tcp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// UDP creates a checker that checks that the transport protocol is UDP and -// potentially additional transport header fields. -func UDP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) - } - - udp := header.UDP(last.Payload()) - for _, f := range checkers { - f(t, udp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// SrcPort creates a checker that checks the source port. -func SrcPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) - } - } -} - -// DstPort creates a checker that checks the destination port. -func DstPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) - } - } -} - -// SeqNum creates a checker that checks the sequence number. -func SeqNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) - } - } -} - -// AckNum creates a checker that checks the ack number. -func AckNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) - } - } -} - -// Window creates a checker that checks the tcp window. -func Window(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) - } - } -} - -// TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) - } - } -} - -// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the -// given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) - } - } -} - -// TCPSynOptions creates a checker that checks the presence of TCP options in -// SYN segments. -// -// If wndscale is negative, the window scale option must not be present. -func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := tcp.Options() - limit := len(opts) - foundMSS := false - foundWS := false - foundTS := false - foundSACKPermitted := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionMSS: - v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) - if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) - } - foundMSS = true - i += 4 - case header.TCPOptionWS: - if wantOpts.WS < 0 { - t.Error("WS present when it shouldn't be") - } - v := int(opts[i+2]) - if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) - } - foundWS = true - i += 3 - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = uint32(0) - if tcp.Flags()&header.TCPFlagAck != 0 { - // If the syn is an SYN-ACK then read - // the tsEcr value as well. - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - } - foundTS = true - i += 10 - case header.TCPOptionSACKPermitted: - if i+1 >= limit { - t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) - } - if opts[i+1] != 2 { - t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) - } - foundSACKPermitted = true - i += 2 - - default: - i += int(opts[i+1]) - } - } - - if !foundMSS { - t.Errorf("MSS option not found. Options: %x", opts) - } - - if !foundWS && wantOpts.WS >= 0 { - t.Errorf("WS option not found. Options: %x", opts) - } - if wantOpts.TS && !foundTS { - t.Errorf("TS option not found. Options: %x", opts) - } - if foundTS && tsVal == 0 { - t.Error("TS option specified but the timestamp value is zero") - } - if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) - } - if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Errorf("SACKPermitted option not found. Options: %x", opts) - } - } -} - -// TCPTimestampChecker creates a checker that validates that a TCP segment has a -// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and -// wantTSEcr values with those in the TCP segment (if present). -// -// If wantTSVal or wantTSEcr is zero then the corresponding comparison is -// skipped. -func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := []byte(tcp.Options()) - limit := len(opts) - foundTS := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - foundTS = true - i += 10 - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - return - } - l := int(opts[i+1]) - if i < 2 || i+l > limit { - return - } - i += l - } - } - - if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) - } - if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) - } - if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) - } - } -} - -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not -// contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - -// TCPSACKBlockChecker creates a checker that verifies that the segment does -// contain the specified SACK blocks in the TCP options. -func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - var gotSACKBlocks []header.SACKBlock - - opts := []byte(tcp.Options()) - limit := len(opts) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionSACK: - if i+2 > limit { - // Malformed SACK block. - t.Errorf("malformed SACK option in options: %v", opts) - } - sackOptionLen := int(opts[i+1]) - if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { - // Malformed SACK block. - t.Errorf("malformed SACK option length in options: %v", opts) - } - numBlocks := sackOptionLen / 8 - for j := 0; j < numBlocks; j++ { - start := binary.BigEndian.Uint32(opts[i+2+j*8:]) - end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) - gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ - Start: seqnum.Value(start), - End: seqnum.Value(end), - }) - } - i += sackOptionLen - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - break - } - l := int(opts[i+1]) - if l < 2 || i+l > limit { - break - } - i += l - } - } - - if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) - } - } -} - -// Payload creates a checker that checks the payload. -func Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and -// potentially additional ICMPv4 header fields. -func ICMPv4(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) - } - - icmp := header.ICMPv4(last.Payload()) - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv4Type creates a checker that checks the ICMPv4 Type field. -func ICMPv4Type(want header.ICMPv4Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) - } - if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) - } - } -} - -// ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) - } - if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) - } - } -} - -// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and -// potentially additional ICMPv6 header fields. -// -// ICMPv6 will validate the checksum field before calling checkers. -func ICMPv6(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) - } - - icmp := header.ICMPv6(last.Payload()) - if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { - t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv6Type creates a checker that checks the ICMPv6 Type field. -func ICMPv6Type(want header.ICMPv6Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) - } - if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) - } - } -} - -// ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) - } - if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) - } - } -} - -// NDP creates a checker that checks that the packet contains a valid NDP -// message for type of ty, with potentially additional checks specified by -// checkers. -// -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDP message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.NDPPayload()); got < minSize { - t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// NDPNS creates a checker that checks that the packet contains a valid NDP -// Neighbor Solicitation message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the messages concerned. The values within -// the message are up to checkers to validate. -func NDPNS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) -} - -// NDPNSTargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborSolicit. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSTargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) - - if got := ns.TargetAddress(); got != want { - t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) - } - } -} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD deleted file mode 100644 index e648efa71..000000000 --- a/pkg/tcpip/hash/jenkins/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "jenkins", - srcs = ["jenkins.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", - visibility = ["//visibility:public"], -) - -go_test( - name = "jenkins_test", - size = "small", - srcs = [ - "jenkins_test.go", - ], - embed = [":jenkins"], -) diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go new file mode 100755 index 000000000..310f0ee6d --- /dev/null +++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package jenkins + diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go deleted file mode 100644 index 4c78b5808..000000000 --- a/pkg/tcpip/hash/jenkins/jenkins_test.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package jenkins - -import ( - "bytes" - "encoding/binary" - "hash" - "hash/fnv" - "math" - "testing" -) - -func TestGolden32(t *testing.T) { - var golden32 = []struct { - out []byte - in string - }{ - {[]byte{0x00, 0x00, 0x00, 0x00}, ""}, - {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"}, - {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"}, - {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"}, - } - - hash := New32() - - for _, g := range golden32 { - hash.Reset() - done, error := hash.Write([]byte(g.in)) - if error != nil { - t.Fatalf("write error: %s", error) - } - if done != len(g.in) { - t.Fatalf("wrote only %d out of %d bytes", done, len(g.in)) - } - if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) { - t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out) - } - } -} - -func TestIntegrity32(t *testing.T) { - data := []byte{'1', '2', 3, 4, 5} - - h := New32() - h.Write(data) - sum := h.Sum(nil) - - if size := h.Size(); size != len(sum) { - t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum)) - } - - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data[:2]) - h.Write(data[2:]) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a) - } - - sum32 := h.(hash.Hash32).Sum32() - if sum32 != binary.BigEndian.Uint32(sum) { - t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32) - } -} - -func BenchmarkJenkins32KB(b *testing.B) { - h := New32() - - b.SetBytes(1024) - data := make([]byte, 1024) - for i := range data { - data[i] = byte(i) - } - in := make([]byte, 0, h.Size()) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - h.Reset() - h.Write(data) - h.Sum(in) - } -} - -func BenchmarkFnv32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - - h := fnv.New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - c := 0 - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - if c == 0 { - b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N) - } - c++ - } - } - if c > 0 { - b.Logf("Unbalanced buckets: %d", c) - } - } -} - -func BenchmarkSum32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := Sum32(0) - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} - -func BenchmarkNew32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD deleted file mode 100644 index 8392cb9e5..000000000 --- a/pkg/tcpip/header/BUILD +++ /dev/null @@ -1,62 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "header", - srcs = [ - "arp.go", - "checksum.go", - "eth.go", - "gue.go", - "icmpv4.go", - "icmpv6.go", - "interfaces.go", - "ipv4.go", - "ipv6.go", - "ipv6_fragment.go", - "ndp_neighbor_advert.go", - "ndp_neighbor_solicit.go", - "ndp_options.go", - "ndp_router_advert.go", - "tcp.go", - "udp.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/seqnum", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "header_x_test", - size = "small", - srcs = [ - "checksum_test.go", - "ipversion_test.go", - "tcp_test.go", - ], - deps = [ - ":header", - "//pkg/tcpip/buffer", - ], -) - -go_test( - name = "header_test", - size = "small", - srcs = [ - "eth_test.go", - "ndp_test.go", - ], - embed = [":header"], - deps = [ - "//pkg/tcpip", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go deleted file mode 100644 index 86b466c1c..000000000 --- a/pkg/tcpip/header/checksum_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package header provides the implementation of the encoding and decoding of -// network protocol headers. -package header_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestChecksumVVWithOffset(t *testing.T) { - testCases := []struct { - name string - vv buffer.VectorisedView - off, size int - initial uint16 - want uint16 - }{ - { - name: "empty", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 0, - want: 0, - }, - { - name: "OneView", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 5, - want: 1294, - }, - { - name: "TwoViews", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 0, - size: 11, - want: 33819, - }, - { - name: "TwoViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 1, - size: 11, - want: 33819, - }, - { - name: "ThreeViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 7, - size: 11, - want: 33819, - }, - { - name: "ThreeViewsWithInitial", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), - }), - initial: 77, - off: 7, - size: 11, - want: 33896, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { - t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) - } - v := tc.vv.ToView() - v.TrimFront(tc.off) - v.CapLength(tc.size) - if got, want := header.Checksum(v, tc.initial), tc.want; got != want { - t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) - } - }) - } -} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go deleted file mode 100644 index 6634c90f5..000000000 --- a/pkg/tcpip/header/eth_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestIsValidUnicastEthernetAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.LinkAddress - expected bool - }{ - { - "Nil", - tcpip.LinkAddress([]byte(nil)), - false, - }, - { - "Empty", - tcpip.LinkAddress(""), - false, - }, - { - "InvalidLength", - tcpip.LinkAddress("\x01\x02\x03"), - false, - }, - { - "Unspecified", - unspecifiedEthernetAddress, - false, - }, - { - "Multicast", - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - false, - }, - { - "Valid", - tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), - true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { - t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) - } - }) - } -} diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go new file mode 100755 index 000000000..4fcb89452 --- /dev/null +++ b/pkg/tcpip/header/header_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +package header + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SACKBlock) beforeSave() {} +func (x *SACKBlock) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *SACKBlock) afterLoad() {} +func (x *SACKBlock) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *TCPOptions) beforeSave() {} +func (x *TCPOptions) save(m state.Map) { + x.beforeSave() + m.Save("TS", &x.TS) + m.Save("TSVal", &x.TSVal) + m.Save("TSEcr", &x.TSEcr) + m.Save("SACKBlocks", &x.SACKBlocks) +} + +func (x *TCPOptions) afterLoad() {} +func (x *TCPOptions) load(m state.Map) { + m.Load("TS", &x.TS) + m.Load("TSVal", &x.TSVal) + m.Load("TSEcr", &x.TSEcr) + m.Load("SACKBlocks", &x.SACKBlocks) +} + +func init() { + state.Register("header.SACKBlock", (*SACKBlock)(nil), state.Fns{Save: (*SACKBlock).save, Load: (*SACKBlock).load}) + state.Register("header.TCPOptions", (*TCPOptions)(nil), state.Fns{Save: (*TCPOptions).save, Load: (*TCPOptions).load}) +} diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go deleted file mode 100644 index b5540bf66..000000000 --- a/pkg/tcpip/header/ipversion_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestIPv4(t *testing.T) { - b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) - - const want = header.IPv4Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestIPv6(t *testing.T) { - b := header.IPv6(make([]byte, header.IPv6MinimumSize)) - b.Encode(&header.IPv6Fields{}) - - const want = header.IPv6Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestOtherVersion(t *testing.T) { - const want = header.IPv4Version + header.IPv6Version - b := make([]byte, 1) - b[0] = want << 4 - - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestTooShort(t *testing.T) { - b := make([]byte, 1) - b[0] = (header.IPv4Version + header.IPv6Version) << 4 - - // Get the version of a zero-length slice. - const want = -1 - if v := header.IPVersion(b[:0]); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } - - // Get the version of a nil slice. - if v := header.IPVersion(nil); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go index 505c92668..505c92668 100644..100755 --- a/pkg/tcpip/header/ndp_neighbor_advert.go +++ b/pkg/tcpip/header/ndp_neighbor_advert.go diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go index 3a1b8e139..3a1b8e139 100644..100755 --- a/pkg/tcpip/header/ndp_neighbor_solicit.go +++ b/pkg/tcpip/header/ndp_neighbor_solicit.go diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 06e0bace2..06e0bace2 100644..100755 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index bf7610863..bf7610863 100644..100755 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go deleted file mode 100644 index 2c439d70c..000000000 --- a/pkg/tcpip/header/ndp_test.go +++ /dev/null @@ -1,785 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "bytes" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" -) - -// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit. -func TestNDPNeighborSolicit(t *testing.T) { - b := []byte{ - 0, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - ns := NDPNeighborSolicit(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") - if got := ns.TargetAddress(); got != addr { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) - } - - // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") - ns.SetTargetAddress(addr2) - if got := ns.TargetAddress(); got != addr2 { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } -} - -// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. -func TestNDPNeighborAdvert(t *testing.T) { - b := []byte{ - 160, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - na := NDPNeighborAdvert(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") - if got := na.TargetAddress(); got != addr { - t.Errorf("got TargetAddress = %s, want %s", got, addr) - } - - // Test getting the Router Flag. - if got := na.RouterFlag(); !got { - t.Errorf("got RouterFlag = false, want = true") - } - - // Test getting the Solicited Flag. - if got := na.SolicitedFlag(); got { - t.Errorf("got SolicitedFlag = true, want = false") - } - - // Test getting the Override Flag. - if got := na.OverrideFlag(); !got { - t.Errorf("got OverrideFlag = false, want = true") - } - - // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") - na.SetTargetAddress(addr2) - if got := na.TargetAddress(); got != addr2 { - t.Errorf("got TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } - - // Test updating the Router Flag. - na.SetRouterFlag(false) - if got := na.RouterFlag(); got { - t.Errorf("got RouterFlag = true, want = false") - } - - // Test updating the Solicited Flag. - na.SetSolicitedFlag(true) - if got := na.SolicitedFlag(); !got { - t.Errorf("got SolicitedFlag = false, want = true") - } - - // Test updating the Override Flag. - na.SetOverrideFlag(false) - if got := na.OverrideFlag(); got { - t.Errorf("got OverrideFlag = true, want = false") - } - - // Make sure flags got updated in the backing buffer. - if got := b[ndpNAFlagsOffset]; got != 64 { - t.Errorf("got flags byte = %d, want = 64") - } -} - -func TestNDPRouterAdvert(t *testing.T) { - b := []byte{ - 64, 128, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, - } - - ra := NDPRouterAdvert(b) - - if got := ra.CurrHopLimit(); got != 64 { - t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) - } - - if got := ra.ManagedAddrConfFlag(); !got { - t.Errorf("got ManagedAddrConfFlag = false, want = true") - } - - if got := ra.OtherConfFlag(); got { - t.Errorf("got OtherConfFlag = true, want = false") - } - - if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) - } - - if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) - } - - if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) - } -} - -// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the -// Ethernet address from an NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { - tests := []struct { - name string - buf []byte - expected tcpip.LinkAddress - }{ - { - "ValidMAC", - []byte{1, 2, 3, 4, 5, 6}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - { - "TLLBodyTooShort", - []byte{1, 2, 3, 4, 5}, - tcpip.LinkAddress([]byte(nil)), - }, - { - "TLLBodyLargerThanNeeded", - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tll := NDPTargetLinkLayerAddressOption(test.buf) - if got := tll.EthernetAddress(); got != test.expected { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) - } - }) - } - -} - -// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a -// NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - []byte{}, - []byte{}, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPTargetLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - tll := next.(NDPTargetLinkLayerAddressOption) - if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got tll.MACAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPPrefixInformationOption tests the field getters and serialization of a -// NDPPrefixInformation. -func TestNDPPrefixInformationOption(t *testing.T) { - b := []byte{ - 43, 127, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 5, 5, 5, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPPrefixInformation(b), - } - opts.Serialize(serializer) - expectedBuf := []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - if !bytes.Equal(targetBuf, expectedBuf) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - pi := next.(NDPPrefixInformation) - - if got := pi.Type(); got != 3 { - t.Errorf("got Type = %d, want = 3", got) - } - - if got := pi.Length(); got != 30 { - t.Errorf("got Length = %d, want = 30", got) - } - - if got := pi.PrefixLength(); got != 43 { - t.Errorf("got PrefixLength = %d, want = 43", got) - } - - if pi.OnLinkFlag() { - t.Error("got OnLinkFlag = true, want = false") - } - - if !pi.AutonomousAddressConfigurationFlag() { - t.Error("got AutonomousAddressConfigurationFlag = false, want = true") - } - - if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { - t.Errorf("got ValidLifetime = %d, want = %d", got, want) - } - - if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { - t.Errorf("got PreferredLifetime = %d, want = %d", got, want) - } - - if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { - t.Errorf("got Prefix = %s, want = %s", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 25, 3, 0, 0, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPRecursiveDNSServer(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Type(); got != 25 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16909320*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - want := []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - } - if got := opt.Addresses(); !cmp.Equal(got, want) { - t.Errorf("got Addresses = %v, want = %v", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOption(t *testing.T) { - tests := []struct { - name string - buf []byte - lifetime time.Duration - addrs []tcpip.Address - }{ - { - "Valid1Addr", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - }, - }, - { - "Valid2Addr", - []byte{ - 25, 5, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - }, - }, - { - "Valid3Addr", - []byte{ - 25, 7, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Iterator should get our option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Lifetime(); got != test.lifetime { - t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) - } - if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { - t.Errorf("got Addresses = %v, want = %v", got, test.addrs) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions -// the iterator was returned for is malformed. -func TestNDPOptionsIterCheck(t *testing.T) { - tests := []struct { - name string - buf []byte - expected error - }{ - { - "ZeroLengthField", - []byte{0, 0, 0, 0, 0, 0, 0, 0}, - ErrNDPOptZeroLength, - }, - { - "ValidTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - nil, - }, - { - "TooSmallTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5}, - ErrNDPOptBufExhausted, - }, - { - "ValidPrefixInformation", - []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - nil, - }, - { - "TooSmallPrefixInformation", - []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, - }, - ErrNDPOptBufExhausted, - }, - { - "InvalidPrefixInformationLength", - []byte{ - 3, 3, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - }, - ErrNDPOptMalformedBody, - }, - { - "ValidTargetLinkLayerAddressWithPrefixInformation", - []byte{ - // Target Link-Layer Address. - 2, 1, 1, 2, 3, 4, 5, 6, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - nil, - }, - { - "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", - []byte{ - // Target Link-Layer Address. - 2, 1, 1, 2, 3, 4, 5, 6, - - // 255 is an unrecognized type. If 255 ends up - // being the type for some recognized type, - // update 255 to some other unrecognized value. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - nil, - }, - { - "InvalidRecursiveDNSServerCutsOffAddress", - []byte{ - 25, 4, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 1, 2, 3, 4, 5, 6, 7, - }, - ErrNDPOptMalformedBody, - }, - { - "InvalidRecursiveDNSServerInvalidLengthField", - []byte{ - 25, 2, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, - }, - ErrNDPInvalidLength, - }, - { - "RecursiveDNSServerTooSmall", - []byte{ - 25, 1, 0, 0, - 0, 0, 0, - }, - ErrNDPOptBufExhausted, - }, - { - "RecursiveDNSServerMulticast", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }, - ErrNDPOptMalformedBody, - }, - { - "RecursiveDNSServerUnspecified", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, - ErrNDPOptMalformedBody, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - - if _, err := opts.Iter(true); err != test.expected { - t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) - } - - // test.buf may be malformed but we chose not to check - // the iterator so it must return true. - if _, err := opts.Iter(false); err != nil { - t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) - } - }) - } -} - -// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, -// this test does not actually check any of the option's getters, it simply -// checks the option Type and Body. We have other tests that tests the option -// field gettings given an option body and don't need to duplicate those tests -// here. -func TestNDPOptionsIter(t *testing.T) { - buf := []byte{ - // Target Link-Layer Address. - 2, 1, 1, 2, 3, 4, 5, 6, - - // 255 is an unrecognized type. If 255 ends up being the type - // for some recognized type, update 255 to some other - // unrecognized value. Note, this option should be skipped when - // iterating. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - opts := NDPOptions(buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Test the first (Taret Link-Layer) option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - - // Test the next (Prefix Information) option. - // Note, the unrecognized option should be skipped. - next, done, err = it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go deleted file mode 100644 index 72563837b..000000000 --- a/pkg/tcpip/header/tcp_test.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestEncodeSACKBlocks(t *testing.T) { - testCases := []struct { - sackBlocks []header.SACKBlock - want []header.SACKBlock - bufSize int - }{ - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 40, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, - 30, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}}, - 20, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}}, - 10, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - nil, - 8, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 60, - }, - } - for _, tc := range testCases { - b := make([]byte, tc.bufSize) - t.Logf("testing: %v", tc) - header.EncodeSACKBlocks(tc.sackBlocks, b) - opts := header.ParseTCPOptions(b) - if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want) - } - } -} - -func TestTCPParseOptions(t *testing.T) { - type tsOption struct { - tsVal uint32 - tsEcr uint32 - } - - generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte { - l := 0 - if tsOpt != nil { - l += 10 - } - if len(sackBlocks) != 0 { - l += len(sackBlocks)*8 + 2 - } - b := make([]byte, l) - offset := 0 - if tsOpt != nil { - offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b) - } - header.EncodeSACKBlocks(sackBlocks, b[offset:]) - return b - } - - testCases := []struct { - b []byte - want header.TCPOptions - }{ - // Trivial cases. - {nil, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test timestamp parsing. - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - - // Test malformed timestamp option. - {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test SACKBlock parsing. - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}}, - {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}}, - - // Test malformed SACK option. - {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}}, - - // Test Timestamp + SACK block parsing. - {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}}, - - // Test valid timestamp + malformed SACK block parsing. - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - } - for _, tc := range testCases { - if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want) - } - } -} diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD deleted file mode 100644 index cc5f531e2..000000000 --- a/pkg/tcpip/iptables/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "iptables", - srcs = [ - "iptables.go", - "targets.go", - "types.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", - visibility = ["//visibility:public"], - deps = ["//pkg/tcpip/buffer"], -) diff --git a/pkg/tcpip/iptables/iptables_state_autogen.go b/pkg/tcpip/iptables/iptables_state_autogen.go new file mode 100755 index 000000000..f15092db2 --- /dev/null +++ b/pkg/tcpip/iptables/iptables_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package iptables + diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD deleted file mode 100644 index 7dbc05754..000000000 --- a/pkg/tcpip/link/channel/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "channel", - srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 70188551f..70188551f 100644..100755 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go diff --git a/pkg/tcpip/link/channel/channel_state_autogen.go b/pkg/tcpip/link/channel/channel_state_autogen.go new file mode 100755 index 000000000..19e9e5a2b --- /dev/null +++ b/pkg/tcpip/link/channel/channel_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package channel + diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD deleted file mode 100644 index 897c94821..000000000 --- a/pkg/tcpip/link/fdbased/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fdbased", - srcs = [ - "endpoint.go", - "endpoint_unsafe.go", - "mmap.go", - "mmap_stub.go", - "mmap_unsafe.go", - "packet_dispatchers.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fdbased_test", - size = "small", - srcs = ["endpoint_test.go"], - embed = [":fdbased"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go deleted file mode 100644 index 2066987eb..000000000 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ /dev/null @@ -1,468 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package fdbased - -import ( - "bytes" - "fmt" - "math/rand" - "reflect" - "syscall" - "testing" - "time" - "unsafe" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - mtu = 1500 - laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") - raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") - proto = 10 - csumOffset = 48 - gsoMSS = 500 -) - -type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents tcpip.PacketBuffer -} - -type context struct { - t *testing.T - fds [2]int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} -} - -func newContext(t *testing.T, opt *Options) *context { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) - if err != nil { - t.Fatalf("Socketpair failed: %v", err) - } - - done := make(chan struct{}, 1) - opt.ClosedFunc = func(*tcpip.Error) { - done <- struct{}{} - } - - opt.FDs = []int{fds[1]} - ep, err := New(opt) - if err != nil { - t.Fatalf("Failed to create FD endpoint: %v", err) - } - - c := &context{ - t: t, - fds: fds, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, - } - - ep.Attach(c) - - return c -} - -func (c *context) cleanup() { - syscall.Close(c.fds[0]) - <-c.done - syscall.Close(c.fds[1]) -} - -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - c.ch <- packetInfo{remote, protocol, pkt} -} - -func TestNoEthernetProperties(t *testing.T) { - c := newContext(t, &Options{MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestEthernetProperties(t *testing.T) { - c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestAddress(t *testing.T) { - addrs := []tcpip.LinkAddress{"", "abc", "def"} - for _, a := range addrs { - t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { - c := newContext(t, &Options{Address: a, MTU: mtu}) - defer c.cleanup() - - if want, v := a, c.ep.LinkAddress(); want != v { - t.Fatalf("LinkAddress() = %v, want %v", v, want) - } - }) - } -} - -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) - defer c.cleanup() - - r := &stack.Route{ - RemoteLinkAddress: raddr, - } - - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) - } - - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) - } - want := append(hdr.View(), payload...) - var gso *stack.GSO - if gsoMaxSize != 0 { - gso = &stack.GSO{ - Type: stack.GSOTCPv6, - NeedsCsum: true, - CsumOffset: csumOffset, - MSS: gsoMSS, - MaxSize: gsoMaxSize, - L3HdrLen: header.IPv4MaximumHeaderSize, - } - } - if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from fd, then compare with what we wrote. - b = make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - if gsoMaxSize != 0 { - vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) - if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { - t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) - } - csumStart := header.EthernetMinimumSize + gso.L3HdrLen - if vnetHdr.csumStart != csumStart { - t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) - } - if vnetHdr.csumOffset != csumOffset { - t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) - } - gsoType := uint8(0) - if int(gso.MSS) < plen { - gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - } - if vnetHdr.gsoType != gsoType { - t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) - } - b = b[virtioNetHdrSize:] - } - if eth { - h := header.Ethernet(b) - b = b[header.EthernetMinimumSize:] - - if a := h.SourceAddress(); a != laddr { - t.Fatalf("SourceAddress() = %v, want %v", a, laddr) - } - - if a := h.DestinationAddress(); a != raddr { - t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) - } - - if et := h.Type(); et != proto { - t.Fatalf("Type() = %v, want %v", et, proto) - } - } - if len(b) != len(want) { - t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) - } - if !bytes.Equal(b, want) { - t.Fatalf("Read returned %x, want %x", b, want) - } -} - -func TestWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} - gsos := []uint32{0, 32768} - - for _, eth := range eths { - for _, plen := range lengths { - for _, gso := range gsos { - t.Run( - fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), - func(t *testing.T) { - testWritePacket(t, plen, eth, gso) - }, - ) - } - } - } -} - -func TestPreserveSrcAddress(t *testing.T) { - baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") - - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true}) - defer c.cleanup() - - // Set LocalLinkAddress in route to the value of the bridged address. - r := &stack.Route{ - RemoteLinkAddress: raddr, - LocalLinkAddress: baddr, - } - - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.VectorisedView{}, - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from the FD, then compare with what we wrote. - b := make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - h := header.Ethernet(b) - - if a := h.SourceAddress(); a != baddr { - t.Fatalf("SourceAddress() = %v, want %v", a, baddr) - } -} - -func TestDeliverPacket(t *testing.T) { - lengths := []int{100, 1000} - eths := []bool{true, false} - - for _, eth := range eths { - for _, plen := range lengths { - t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) - defer c.cleanup() - - // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) - } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) - hdr.Encode(&header.EthernetFields{ - SrcAddr: raddr, - DstAddr: laddr, - Type: proto, - }) - all = append(hdr, b...) - } - - // Write packet via the file descriptor. - if _, err := syscall.Write(c.fds[0], all); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Receive packet through the endpoint. - select { - case pi := <-c.ch: - want := packetInfo{ - raddr: raddr, - proto: proto, - contents: tcpip.PacketBuffer{ - Data: buffer.View(b).ToVectorisedView(), - LinkHeader: buffer.View(hdr), - }, - } - if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - // want.contents.Data will be a single - // view, so make pi do the same for the - // DeepEqual check. - pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) - } - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for packet") - } - }) - } - } -} - -func TestBufConfigMaxLength(t *testing.T) { - got := 0 - for _, i := range BufConfig { - got += i - } - want := header.MaxIPPacketSize // maximum TCP packet size - if got < want { - t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) - } -} - -func TestBufConfigFirst(t *testing.T) { - // The stack assumes that the TCP/IP header is enterily contained in the first view. - // Therefore, the first view needs to be large enough to contain the maximum TCP/IP - // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). - want := 120 - got := BufConfig[0] - if got < want { - t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) - } -} - -var capLengthTestCases = []struct { - comment string - config []int - n int - wantUsed int - wantLengths []int -}{ - { - comment: "Single slice", - config: []int{2}, - n: 1, - wantUsed: 1, - wantLengths: []int{1}, - }, - { - comment: "Multiple slices", - config: []int{1, 2}, - n: 2, - wantUsed: 2, - wantLengths: []int{1, 1}, - }, - { - comment: "Entire buffer", - config: []int{1, 2}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2}, - }, - { - comment: "Entire buffer but not on the last slice", - config: []int{1, 2, 3}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2, 3}, - }, -} - -func TestReadVDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - // fd does not matter for this test. - d := readVDispatcher{fd: -1, e: &endpoint{}} - d.views = make([]buffer.View, len(c.config)) - d.iovecs = make([]syscall.Iovec, len(c.config)) - d.allocateViews(c.config) - - used := d.capViews(c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views)) - for i, v := range d.views { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - } -} - -func TestRecvMMsgDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - d := recvMMsgDispatcher{ - fd: -1, // fd does not matter for this test. - e: &endpoint{}, - views: make([][]buffer.View, 1), - iovecs: make([][]syscall.Iovec, 1), - msgHdrs: make([]rawfile.MMsgHdr, 1), - } - - for i, _ := range d.views { - d.views[i] = make([]buffer.View, len(c.config)) - } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, len(c.config)) - } - for k, msgHdr := range d.msgHdrs { - msgHdr.Msg.Iov = &d.iovecs[k][0] - msgHdr.Msg.Iovlen = uint64(len(c.config)) - } - - d.allocateViews(c.config) - - used := d.capViews(0, c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views[0])) - for i, v := range d.views[0] { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - - } -} diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go new file mode 100755 index 000000000..0555db528 --- /dev/null +++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package fdbased + diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD deleted file mode 100644 index f35fcdff4..000000000 --- a/pkg/tcpip/link/loopback/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "loopback", - srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go new file mode 100755 index 000000000..87ec8cfc7 --- /dev/null +++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package loopback + diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD deleted file mode 100644 index 1ac7948b6..000000000 --- a/pkg/tcpip/link/muxed/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "muxed", - srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "muxed_test", - size = "small", - srcs = ["injectable_test.go"], - embed = [":muxed"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go deleted file mode 100644 index 445b22c17..000000000 --- a/pkg/tcpip/link/muxed/injectable.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package muxed provides a muxed link endpoints. -package muxed - -import ( - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// InjectableEndpoint is an injectable multi endpoint. The endpoint has -// trivial routing rules that determine which InjectableEndpoint a given packet -// will be written to. Note that HandleLocal works differently for this -// endpoint (see WritePacket). -type InjectableEndpoint struct { - routes map[tcpip.Address]stack.InjectableLinkEndpoint - dispatcher stack.NetworkDispatcher -} - -// MTU implements stack.LinkEndpoint. -func (m *InjectableEndpoint) MTU() uint32 { - minMTU := ^uint32(0) - for _, endpoint := range m.routes { - if endpointMTU := endpoint.MTU(); endpointMTU < minMTU { - minMTU = endpointMTU - } - } - return minMTU -} - -// Capabilities implements stack.LinkEndpoint. -func (m *InjectableEndpoint) Capabilities() stack.LinkEndpointCapabilities { - minCapabilities := stack.LinkEndpointCapabilities(^uint(0)) - for _, endpoint := range m.routes { - minCapabilities &= endpoint.Capabilities() - } - return minCapabilities -} - -// MaxHeaderLength implements stack.LinkEndpoint. -func (m *InjectableEndpoint) MaxHeaderLength() uint16 { - minHeaderLen := ^uint16(0) - for _, endpoint := range m.routes { - if headerLen := endpoint.MaxHeaderLength(); headerLen < minHeaderLen { - minHeaderLen = headerLen - } - } - return minHeaderLen -} - -// LinkAddress implements stack.LinkEndpoint. -func (m *InjectableEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -// Attach implements stack.LinkEndpoint. -func (m *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - for _, endpoint := range m.routes { - endpoint.Attach(dispatcher) - } - m.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint. -func (m *InjectableEndpoint) IsAttached() bool { - return m.dispatcher != nil -} - -// InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) -} - -// WritePackets writes outbound packets to the appropriate -// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if -// r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - endpoint, ok := m.routes[r.RemoteAddress] - if !ok { - return 0, tcpip.ErrNoRoute - } - return endpoint.WritePackets(r, gso, pkts, protocol) -} - -// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint -// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a -// route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, gso, protocol, pkt) - } - return tcpip.ErrNoRoute -} - -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // WriteRawPacket doesn't get a route or network address, so there's - // nowhere to write this. - return tcpip.ErrNoRoute -} - -// InjectOutbound writes outbound packets to the appropriate -// LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { - endpoint, ok := m.routes[dest] - if !ok { - return tcpip.ErrNoRoute - } - return endpoint.InjectOutbound(dest, packet) -} - -// Wait implements stack.LinkEndpoint.Wait. -func (m *InjectableEndpoint) Wait() { - for _, ep := range m.routes { - ep.Wait() - } -} - -// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { - return &InjectableEndpoint{ - routes: routes, - } -} diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go deleted file mode 100644 index 63b249837..000000000 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package muxed - -import ( - "bytes" - "net" - "os" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -func TestInjectableEndpointRawDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - endpoint.InjectOutbound(dstIP, []byte{0xFA}) - - buf := make([]byte, ipv4.MaxTotalSize) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) - - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.NewView(0).ToVectorisedView(), - }) - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) { - dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4()) - pair, err := syscall.Socketpair(syscall.AF_UNIX, - syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0) - if err != nil { - t.Fatal("Failed to create socket pair:", err) - } - underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) - routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - endpoint := NewInjectableEndpoint(routes) - return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP -} diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD deleted file mode 100644 index d8211e93d..000000000 --- a/pkg/tcpip/link/rawfile/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "rawfile", - srcs = [ - "blockingpoll_amd64.s", - "blockingpoll_arm64.s", - "blockingpoll_noyield_unsafe.go", - "blockingpoll_yield_unsafe.go", - "errors.go", - "rawfile_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go new file mode 100755 index 000000000..662c04444 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package rawfile + diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD deleted file mode 100644 index a4f9cdd69..000000000 --- a/pkg/tcpip/link/sharedmem/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sharedmem", - srcs = [ - "rx.go", - "sharedmem.go", - "sharedmem_unsafe.go", - "tx.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], - embed = [":sharedmem"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/sharedmem/pipe", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD deleted file mode 100644 index 6b5bc542c..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = [ - "pipe.go", - "pipe_unsafe.go", - "rx.go", - "tx.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", - visibility = ["//visibility:public"], -) - -go_test( - name = "pipe_test", - srcs = [ - "pipe_test.go", - ], - embed = [":pipe"], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go deleted file mode 100644 index 74c9f0311..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package pipe implements a shared memory ring buffer on which a single reader -// and a single writer can operate (read/write) concurrently. The ring buffer -// allows for data of different sizes to be written, and preserves the boundary -// of the written data. -// -// Example usage is as follows: -// -// wb := t.Push(20) -// // Write data to wb. -// t.Flush() -// -// rb := r.Pull() -// // Do something with data in rb. -// t.Flush() -package pipe - -import ( - "math" -) - -const ( - jump uint64 = math.MaxUint32 + 1 - offsetMask uint64 = math.MaxUint32 - revolutionMask uint64 = ^offsetMask - - sizeOfSlotHeader = 8 // sizeof(uint64) - slotFree uint64 = 1 << 63 - slotSizeMask uint64 = math.MaxUint32 -) - -// payloadToSlotSize calculates the total size of a slot based on its payload -// size. The total size is the header size, plus the payload size, plus padding -// if necessary to make the total size a multiple of sizeOfSlotHeader. -func payloadToSlotSize(payloadSize uint64) uint64 { - s := sizeOfSlotHeader + payloadSize - return (s + sizeOfSlotHeader - 1) &^ (sizeOfSlotHeader - 1) -} - -// slotToPayloadSize calculates the payload size of a slot based on the total -// size of the slot. This is only meant to be used when creating slots that -// don't carry information (e.g., free slots or wrap slots). -func slotToPayloadSize(offset uint64) uint64 { - return offset - sizeOfSlotHeader -} - -// pipe is a basic data structure used by both (transmit & receive) ends of a -// pipe. Indices into this pipe are split into two fields: offset, which counts -// the number of bytes from the beginning of the buffer, and revolution, which -// counts the number of times the index has wrapped around. -type pipe struct { - buffer []byte -} - -// init initializes the pipe buffer such that its size is a multiple of the size -// of the slot header. -func (p *pipe) init(b []byte) { - p.buffer = b[:len(b)&^(sizeOfSlotHeader-1)] -} - -// data returns a section of the buffer starting at the given index (which may -// include revolution information) and with the given size. -func (p *pipe) data(idx uint64, size uint64) []byte { - return p.buffer[(idx&offsetMask)+sizeOfSlotHeader:][:size] -} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go deleted file mode 100644 index 59ef69a8b..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "math/rand" - "reflect" - "runtime" - "sync" - "testing" -) - -func TestSimpleReadWrite(t *testing.T) { - // Check that a simple write can be properly read from the rx side. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - wb := tx.Push(10) - if wb == nil { - t.Fatalf("Push failed on empty pipe") - } - for i := range wb { - wb[i] = byte(tr.Intn(256)) - } - tx.Flush() - - var rx Rx - rx.Init(b) - rb := rx.Pull() - if len(rb) != 10 { - t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10) - } - - for i := range rb { - if v := byte(rr.Intn(256)); v != rb[i] { - t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v) - } - } - rx.Flush() -} - -func TestEmptyRead(t *testing.T) { - // Check that pulling from an empty pipe fails. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestTooLargeWrite(t *testing.T) { - // Check that writes that are too large are properly rejected. - b := make([]byte, 96) - var tx Tx - tx.Init(b) - - if wb := tx.Push(96); wb != nil { - t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(88); wb != nil { - t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } -} - -func TestFullWrite(t *testing.T) { - // Check that writes fail when the pipe is full. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestFullAndFlushedWrite(t *testing.T) { - // Check that writes fail when the pipe is full and has already been - // flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - tx.Flush() - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestTxFlushTwice(t *testing.T) { - // Checks that a second consecutive tx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // Make copy of original tx queue, flush it, then check that it didn't - // change. - orig := tx - tx.Flush() - - if !reflect.DeepEqual(orig, tx) { - t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig) - } -} - -func TestRxFlushTwice(t *testing.T) { - // Checks that a second consecutive rx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // Make copy of original rx queue, flush it, then check that it didn't - // change. - orig := rx - rx.Flush() - - if !reflect.DeepEqual(orig, rx) { - t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig) - } -} - -func TestWrapInMiddleOfTransaction(t *testing.T) { - // Check that writes are not flushed when we need to wrap the buffer - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestWriteAbort(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it but - // has aborted the push. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Abort() - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestWrappedWriteAbort(t *testing.T) { - // Check that writes are properly aborted even if the writes wrap - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Abort() - - // The pushes were aborted, so no data should be readable. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - // Try the same transactions again, but flush this time. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestEmptyReadOnNonFlushedWrite(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it - // but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Flush() - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull on failed on non-empty pipe") - } -} - -func TestPullAfterPullingEntirePipe(t *testing.T) { - // Check that Pull fails when the pipe is full, but all of it has - // already been pulled but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3 - // buffers that will fill the pipe. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - if wb := tx.Push(24); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The three buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - // Fourth pull must fail. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestNoRoomToWrapOnPush(t *testing.T) { - // Check that Push fails when it tries to allocate room to add a wrap - // message. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20, - // which won't fit (64+20+8+padding = 96, which wouldn't leave room for - // the padding), so it wraps around. - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - tx.Flush() - - // Buffer offset is at 28. Try to write 70, which would require a wrap - // slot which cannot be created now. - if wb := tx.Push(70); wb != nil { - t.Fatalf("Push succeeded on pipe with no room for wrap message") - } -} - -func TestRxImplicitFlushOfWrapMessage(t *testing.T) { - // Check if the first read is that of a wrapping message, that it gets - // immediately flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // This will cause a wrapping message to written. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - var rx Rx - rx.Init(b) - - // Read the first message. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // This should fail because of the wrapping message is taking up space. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - // Try to read the next one. This should consume the wrapping message. - rx.Pull() - - // This must now succeed. - if wb := tx.Push(60); wb == nil { - t.Fatalf("Push failed on empty pipe") - } -} - -func TestConcurrentReaderWriter(t *testing.T) { - // Push a million buffers of random sizes and random contents. Check - // that buffers read match what was written. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - - const count = 1000000 - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + tr.Intn(80) - wb := tx.Push(uint64(n)) - for wb == nil { - wb = tx.Push(uint64(n)) - } - - for j := range wb { - wb[j] = byte(tr.Intn(256)) - } - - tx.Flush() - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } - - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } - - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } - } - - rx.Flush() - } - }() - - wg.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go deleted file mode 100644 index 62d17029e..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "sync/atomic" - "unsafe" -) - -func (p *pipe) write(idx uint64, v uint64) { - ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) - *ptr = v -} - -func (p *pipe) writeAtomic(idx uint64, v uint64) { - ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) - atomic.StoreUint64(ptr, v) -} - -func (p *pipe) readAtomic(idx uint64) uint64 { - ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) - return atomic.LoadUint64(ptr) -} diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go deleted file mode 100644 index f22e533ac..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/rx.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -// Rx is the receive side of the shared memory ring buffer. -type Rx struct { - p pipe - - tail uint64 - head uint64 -} - -// Init initializes the receive end of the pipe. In the initial state, the next -// slot to be inspected is the very first one. -func (r *Rx) Init(b []byte) { - r.p.init(b) - r.tail = 0xfffffffe * jump - r.head = r.tail -} - -// Pull reads the next buffer from the pipe, returning nil if there isn't one -// currently available. -// -// The returned slice is available until Flush() is next called. After that, it -// must not be touched. -func (r *Rx) Pull() []byte { - if r.head == r.tail+jump { - // We've already pulled the whole pipe. - return nil - } - - header := r.p.readAtomic(r.head) - if header&slotFree != 0 { - // The next slot is free, we can't pull it yet. - return nil - } - - payloadSize := header & slotSizeMask - newHead := r.head + payloadToSlotSize(payloadSize) - headWrap := (r.head & revolutionMask) | uint64(len(r.p.buffer)) - - // Check if this is a wrapping slot. If that's the case, it carries no - // data, so we just skip it and try again from the first slot. - if int64(newHead-headWrap) >= 0 { - if int64(newHead-headWrap) > int64(jump) || newHead&offsetMask != 0 { - return nil - } - - if r.tail == r.head { - // If this is the first pull since the last Flush() - // call, we flush the state so that the sender can use - // this space if it needs to. - r.p.writeAtomic(r.head, slotFree|slotToPayloadSize(newHead-r.head)) - r.tail = newHead - } - - r.head = newHead - return r.Pull() - } - - // Grab the buffer before updating r.head. - b := r.p.data(r.head, payloadSize) - r.head = newHead - return b -} - -// Flush tells the transmitter that all buffers pulled since the last Flush() -// have been used, so the transmitter is free to used their slots for further -// transmission. -func (r *Rx) Flush() { - if r.head == r.tail { - return - } - r.p.writeAtomic(r.tail, slotFree|slotToPayloadSize(r.head-r.tail)) - r.tail = r.head -} - -// Bytes returns the byte slice on which the pipe operates. -func (r *Rx) Bytes() []byte { - return r.p.buffer -} diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go deleted file mode 100644 index 9841eb231..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/tx.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -// Tx is the transmit side of the shared memory ring buffer. -type Tx struct { - p pipe - maxPayloadSize uint64 - - head uint64 - tail uint64 - next uint64 - - tailHeader uint64 -} - -// Init initializes the transmit end of the pipe. In the initial state, the next -// slot to be written is the very first one, and the transmitter has the whole -// ring buffer available to it. -func (t *Tx) Init(b []byte) { - t.p.init(b) - // maxPayloadSize excludes the header of the payload, and the header - // of the wrapping message. - t.maxPayloadSize = uint64(len(t.p.buffer)) - 2*sizeOfSlotHeader - t.tail = 0xfffffffe * jump - t.next = t.tail - t.head = t.tail + jump - t.p.write(t.tail, slotFree) -} - -// Capacity determines how many records of the given size can be written to the -// pipe before it fills up. -func (t *Tx) Capacity(recordSize uint64) uint64 { - available := uint64(len(t.p.buffer)) - sizeOfSlotHeader - entryLen := payloadToSlotSize(recordSize) - return available / entryLen -} - -// Push reserves "payloadSize" bytes for transmission in the pipe. The caller -// populates the returned slice with the data to be transferred and enventually -// calls Flush() to make the data visible to the reader, or Abort() to make the -// pipe forget all Push() calls since the last Flush(). -// -// The returned slice is available until Flush() or Abort() is next called. -// After that, it must not be touched. -func (t *Tx) Push(payloadSize uint64) []byte { - // Fail request if we know we will never have enough room. - if payloadSize > t.maxPayloadSize { - return nil - } - - totalLen := payloadToSlotSize(payloadSize) - newNext := t.next + totalLen - nextWrap := (t.next & revolutionMask) | uint64(len(t.p.buffer)) - if int64(newNext-nextWrap) >= 0 { - // The new buffer would overflow the pipe, so we push a wrapping - // slot, then try to add the actual slot to the front of the - // pipe. - newNext = (newNext & revolutionMask) + jump - wrappingPayloadSize := slotToPayloadSize(newNext - t.next) - if !t.reclaim(newNext) { - return nil - } - - oldNext := t.next - t.next = newNext - if oldNext != t.tail { - t.p.write(oldNext, wrappingPayloadSize) - } else { - t.tailHeader = wrappingPayloadSize - t.Flush() - } - - newNext += totalLen - } - - // Check that we have enough room for the buffer. - if !t.reclaim(newNext) { - return nil - } - - if t.next != t.tail { - t.p.write(t.next, payloadSize) - } else { - t.tailHeader = payloadSize - } - - // Grab the buffer before updating t.next. - b := t.p.data(t.next, payloadSize) - t.next = newNext - - return b -} - -// reclaim attempts to advance the head until at least newNext. If the head is -// already at or beyond newNext, nothing happens and true is returned; otherwise -// it tries to reclaim slots that have already been consumed by the receive end -// of the pipe (they will be marked as free) and returns a boolean indicating -// whether it was successful in reclaiming enough slots. -func (t *Tx) reclaim(newNext uint64) bool { - for int64(newNext-t.head) > 0 { - // Can't reclaim if slot is not free. - header := t.p.readAtomic(t.head) - if header&slotFree == 0 { - return false - } - - payloadSize := header & slotSizeMask - newHead := t.head + payloadToSlotSize(payloadSize) - - // Check newHead is within bounds and valid. - if int64(newHead-t.tail) > int64(jump) || newHead&offsetMask >= uint64(len(t.p.buffer)) { - return false - } - - t.head = newHead - } - - return true -} - -// Abort causes all Push() calls since the last Flush() to be forgotten and -// therefore they will not be made visible to the receiver. -func (t *Tx) Abort() { - t.next = t.tail -} - -// Flush causes all buffers pushed since the last Flush() [or Abort(), whichever -// is the most recent] to be made visible to the receiver. -func (t *Tx) Flush() { - if t.next == t.tail { - // Nothing to do if there are no pushed buffers. - return - } - - if t.next != t.head { - // The receiver will spin in t.next, so we must make sure that - // the slotFree bit is set. - t.p.write(t.next, slotFree) - } - - t.p.writeAtomic(t.tail, t.tailHeader) - t.tail = t.next -} - -// Bytes returns the byte slice on which the pipe operates. -func (t *Tx) Bytes() []byte { - return t.p.buffer -} diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD deleted file mode 100644 index 8c9234d54..000000000 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "queue", - srcs = [ - "rx.go", - "tx.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip/link/sharedmem/pipe", - ], -) - -go_test( - name = "queue_test", - srcs = [ - "queue_test.go", - ], - embed = [":queue"], - deps = [ - "//pkg/tcpip/link/sharedmem/pipe", - ], -) diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go deleted file mode 100644 index 9a0aad5d7..000000000 --- a/pkg/tcpip/link/sharedmem/queue/queue_test.go +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -func TestBasicTxQueue(t *testing.T) { - // Tests that a basic transmit on a queue works, and that completion - // gets properly reported as well. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Enqueue two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on empty queue") - } - - // Check the contents of the pipe. - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - - want := []byte{ - 234, 3, 0, 0, 0, 0, 0, 0, // id - 100, 0, 0, 0, // total size - 0, 0, 0, 0, // reserved - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - } - - if !reflect.DeepEqual(want, d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want) - } - - rxp.Flush() - - // Check that there are no completions yet. - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Packet reported as completed too soon") - } - - // Post a completion. - d = txp.Push(8) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - - // Check that completion is properly reported. - id, ok := q.CompletedPacket() - if !ok { - t.Fatalf("Completion not reported") - } - - if id != usedID { - t.Fatalf("Bad completion id: got %v, want %v", id, usedID) - } -} - -func TestBasicRxQueue(t *testing.T) { - // Tests that a basic receive on a queue works. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on empty queue") - } - - // Check the contents of the pipe. - want := [][]byte{ - { - 100, 0, 0, 0, 0, 0, 0, 0, // Offset1 - 60, 0, 0, 0, // Size1 - 0, 0, 0, 0, // Remaining in group 1 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - }, - { - 200, 0, 0, 0, 0, 0, 0, 0, // Offset2 - 40, 0, 0, 0, // Size2 - 0, 0, 0, 0, // Remaining in group 2 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }, - } - - for i := range b { - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - - if !reflect.DeepEqual(want[i], d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want[i]) - } - - rxp.Flush() - } - - // Check that there are no completions. - if _, n := q.Dequeue(nil); n != 0 { - t.Fatalf("Packet reported as received too soon") - } - - // Post a completion. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - // Check that completion is properly reported. - bufs, n := q.Dequeue(nil) - if n != 100 { - t.Fatalf("Bad packet size: got %v, want %v", n, 100) - } - - if !reflect.DeepEqual(bufs, b) { - t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b) - } -} - -func TestBadTxCompletion(t *testing.T) { - // Check that tx completions with bad sizes are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion that is too long, and check that it is ignored. - if d := txp.Push(10); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } -} - -func TestBadRxCompletion(t *testing.T) { - // Check that bad rx completions are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes add up to less than the total - // size. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 10, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 10, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes will cause a 32-bit overflow, - // but adds up to the right number. - d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 255, 255, 255, 255, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 101, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } -} - -func TestFillTxPipe(t *testing.T) { - // Check that transmitting a new buffer when the buffer pipe is full - // fails gracefully. - pb1 := make([]byte, 104) - pb2 := make([]byte, 104) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Transmit twice, which should fill the tx pipe. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - for i := uint64(0); i < 2; i++ { - if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) { - t.Fatalf("Failed to transmit buffer") - } - } - - // Transmit another packet now that the tx pipe is full. - if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue succeeded when tx pipe is full") - } -} - -func TestFillRxPipe(t *testing.T) { - // Check that posting a new buffer when the buffer pipe is full fails - // gracefully. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a buffer twice, it should fill the tx pipe. - b := []RxBuffer{ - {100, 60, 1077, 0}, - } - - for i := 0; i < 2; i++ { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - } - - // Post another buffer now that the tx pipe is full. - if q.PostBuffers(b) { - t.Fatalf("PostBuffers succeeded on full queue") - } -} - -func TestLotsOfTransmissions(t *testing.T) { - // Make sure pipes are being properly flushed when transmitting packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Prepare packet with two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - - // Post 100000 packets and completions. - for i := 100000; i > 0; i-- { - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - rxp.Flush() - - d := txp.Push(8) - if d == nil { - t.Fatalf("Unable to write to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - if _, ok := q.CompletedPacket(); !ok { - t.Fatalf("Completion not returned") - } - } -} - -func TestLotsOfReceptions(t *testing.T) { - // Make sure pipes are being properly flushed when receiving packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Prepare for posting two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - // Post 100000 buffers and completions. - for i := 100000; i > 0; i-- { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - if _, n := q.Dequeue(nil); n == 0 { - t.Fatalf("Dequeue failed when there is a completion") - } - } -} - -func TestRxEnableNotification(t *testing.T) { - // Check that enabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.EnableNotification() - if state != eventFDEnabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled) - } -} - -func TestRxDisableNotification(t *testing.T) { - // Check that disabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.DisableNotification() - if state != eventFDDisabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled) - } -} diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go deleted file mode 100644 index 696e6c9e5..000000000 --- a/pkg/tcpip/link/sharedmem/queue/rx.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package queue provides the implementation of transmit and receive queues -// based on shared memory ring buffers. -package queue - -import ( - "encoding/binary" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -const ( - // Offsets within a posted buffer. - postedOffset = 0 - postedSize = 8 - postedRemainingInGroup = 12 - postedUserData = 16 - postedID = 24 - - sizeOfPostedBuffer = 32 - - // Offsets within a received packet header. - consumedPacketSize = 0 - consumedPacketReserved = 4 - - sizeOfConsumedPacketHeader = 8 - - // Offsets within a consumed buffer. - consumedOffset = 0 - consumedSize = 8 - consumedUserData = 12 - consumedID = 20 - - sizeOfConsumedBuffer = 28 - - // The following are the allowed states of the shared data area. - eventFDUninitialized = 0 - eventFDDisabled = 1 - eventFDEnabled = 2 -) - -// RxBuffer is the descriptor of a receive buffer. -type RxBuffer struct { - Offset uint64 - Size uint32 - ID uint64 - UserData uint64 -} - -// Rx is a receive queue. It is implemented with one tx and one rx pipe: the tx -// pipe is used to "post" buffers, while the rx pipe is used to receive packets -// whose contents have been written to previously posted buffers. -// -// This struct is thread-compatible. -type Rx struct { - tx pipe.Tx - rx pipe.Rx - sharedEventFDState *uint32 -} - -// Init initializes the receive queue with the given pipes, and shared state -// pointer -- the latter is used to enable/disable eventfd notifications. -func (r *Rx) Init(tx, rx []byte, sharedEventFDState *uint32) { - r.sharedEventFDState = sharedEventFDState - r.tx.Init(tx) - r.rx.Init(rx) -} - -// EnableNotification updates the shared state such that the peer will notify -// the eventfd when there are packets to be dequeued. -func (r *Rx) EnableNotification() { - atomic.StoreUint32(r.sharedEventFDState, eventFDEnabled) -} - -// DisableNotification updates the shared state such that the peer will not -// notify the eventfd. -func (r *Rx) DisableNotification() { - atomic.StoreUint32(r.sharedEventFDState, eventFDDisabled) -} - -// PostedBuffersLimit returns the maximum number of buffers that can be posted -// before the tx queue fills up. -func (r *Rx) PostedBuffersLimit() uint64 { - return r.tx.Capacity(sizeOfPostedBuffer) -} - -// PostBuffers makes the given buffers available for receiving data from the -// peer. Once they are posted, the peer is free to write to them and will -// eventually post them back for consumption. -func (r *Rx) PostBuffers(buffers []RxBuffer) bool { - for i := range buffers { - b := r.tx.Push(sizeOfPostedBuffer) - if b == nil { - r.tx.Abort() - return false - } - - pb := &buffers[i] - binary.LittleEndian.PutUint64(b[postedOffset:], pb.Offset) - binary.LittleEndian.PutUint32(b[postedSize:], pb.Size) - binary.LittleEndian.PutUint32(b[postedRemainingInGroup:], 0) - binary.LittleEndian.PutUint64(b[postedUserData:], pb.UserData) - binary.LittleEndian.PutUint64(b[postedID:], pb.ID) - } - - r.tx.Flush() - - return true -} - -// Dequeue receives buffers that have been previously posted by PostBuffers() -// and that have been filled by the peer and posted back. -// -// This is similar to append() in that new buffers are appended to "bufs", with -// reallocation only if "bufs" doesn't have enough capacity. -func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) { - for { - outBufs := bufs - - // Pull the next descriptor from the rx pipe. - b := r.rx.Pull() - if b == nil { - return bufs, 0 - } - - if len(b) < sizeOfConsumedPacketHeader { - log.Warningf("Ignoring packet header: size (%v) is less than header size (%v)", len(b), sizeOfConsumedPacketHeader) - r.rx.Flush() - continue - } - - totalDataSize := binary.LittleEndian.Uint32(b[consumedPacketSize:]) - - // Calculate the number of buffer descriptors and copy them - // over to the output. - count := (len(b) - sizeOfConsumedPacketHeader) / sizeOfConsumedBuffer - offset := sizeOfConsumedPacketHeader - buffersSize := uint32(0) - for i := count; i > 0; i-- { - s := binary.LittleEndian.Uint32(b[offset+consumedSize:]) - buffersSize += s - if buffersSize < s { - // The buffer size overflows an unsigned 32-bit - // integer, so break out and force it to be - // ignored. - totalDataSize = 1 - buffersSize = 0 - break - } - - outBufs = append(outBufs, RxBuffer{ - Offset: binary.LittleEndian.Uint64(b[offset+consumedOffset:]), - Size: s, - ID: binary.LittleEndian.Uint64(b[offset+consumedID:]), - }) - - offset += sizeOfConsumedBuffer - } - - r.rx.Flush() - - if buffersSize < totalDataSize { - // The descriptor is corrupted, ignore it. - log.Warningf("Ignoring packet: actual data size (%v) less than expected size (%v)", buffersSize, totalDataSize) - continue - } - - return outBufs, totalDataSize - } -} - -// Bytes returns the byte slices on which the queue operates. -func (r *Rx) Bytes() (tx, rx []byte) { - return r.tx.Bytes(), r.rx.Bytes() -} - -// DecodeRxBufferHeader decodes the header of a buffer posted on an rx queue. -func DecodeRxBufferHeader(b []byte) RxBuffer { - return RxBuffer{ - Offset: binary.LittleEndian.Uint64(b[postedOffset:]), - Size: binary.LittleEndian.Uint32(b[postedSize:]), - ID: binary.LittleEndian.Uint64(b[postedID:]), - UserData: binary.LittleEndian.Uint64(b[postedUserData:]), - } -} - -// RxCompletionSize returns the number of bytes needed to encode an rx -// completion containing "count" buffers. -func RxCompletionSize(count int) uint64 { - return sizeOfConsumedPacketHeader + uint64(count)*sizeOfConsumedBuffer -} - -// EncodeRxCompletion encodes an rx completion header. -func EncodeRxCompletion(b []byte, size, reserved uint32) { - binary.LittleEndian.PutUint32(b[consumedPacketSize:], size) - binary.LittleEndian.PutUint32(b[consumedPacketReserved:], reserved) -} - -// EncodeRxCompletionBuffer encodes the i-th rx completion buffer header. -func EncodeRxCompletionBuffer(b []byte, i int, rxb RxBuffer) { - b = b[RxCompletionSize(i):] - binary.LittleEndian.PutUint64(b[consumedOffset:], rxb.Offset) - binary.LittleEndian.PutUint32(b[consumedSize:], rxb.Size) - binary.LittleEndian.PutUint64(b[consumedUserData:], rxb.UserData) - binary.LittleEndian.PutUint64(b[consumedID:], rxb.ID) -} diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go deleted file mode 100644 index beffe807b..000000000 --- a/pkg/tcpip/link/sharedmem/queue/tx.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "encoding/binary" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -const ( - // Offsets within a packet header. - packetID = 0 - packetSize = 8 - packetReserved = 12 - - sizeOfPacketHeader = 16 - - // Offsets with a buffer descriptor - bufferOffset = 0 - bufferSize = 8 - - sizeOfBufferDescriptor = 12 -) - -// TxBuffer is the descriptor of a transmit buffer. -type TxBuffer struct { - Next *TxBuffer - Offset uint64 - Size uint32 -} - -// Tx is a transmit queue. It is implemented with one tx and one rx pipe: the -// tx pipe is used to request the transmission of packets, while the rx pipe -// is used to receive which transmissions have completed. -// -// This struct is thread-compatible. -type Tx struct { - tx pipe.Tx - rx pipe.Rx -} - -// Init initializes the transmit queue with the given pipes. -func (t *Tx) Init(tx, rx []byte) { - t.tx.Init(tx) - t.rx.Init(rx) -} - -// Enqueue queues the given linked list of buffers for transmission as one -// packet. While it is queued, the caller must not modify them. -func (t *Tx) Enqueue(id uint64, totalDataLen, bufferCount uint32, buffer *TxBuffer) bool { - // Reserve room in the tx pipe. - totalLen := sizeOfPacketHeader + uint64(bufferCount)*sizeOfBufferDescriptor - - b := t.tx.Push(totalLen) - if b == nil { - return false - } - - // Initialize the packet and buffer descriptors. - binary.LittleEndian.PutUint64(b[packetID:], id) - binary.LittleEndian.PutUint32(b[packetSize:], totalDataLen) - binary.LittleEndian.PutUint32(b[packetReserved:], 0) - - offset := sizeOfPacketHeader - for i := bufferCount; i != 0; i-- { - binary.LittleEndian.PutUint64(b[offset+bufferOffset:], buffer.Offset) - binary.LittleEndian.PutUint32(b[offset+bufferSize:], buffer.Size) - offset += sizeOfBufferDescriptor - buffer = buffer.Next - } - - t.tx.Flush() - - return true -} - -// CompletedPacket returns the id of the last completed transmission. The -// returned id, if any, refers to a value passed on a previous call to -// Enqueue(). -func (t *Tx) CompletedPacket() (id uint64, ok bool) { - for { - b := t.rx.Pull() - if b == nil { - return 0, false - } - - if len(b) != 8 { - t.rx.Flush() - log.Warningf("Ignoring completed packet: size (%v) is less than expected (%v)", len(b), 8) - continue - } - - v := binary.LittleEndian.Uint64(b) - - t.rx.Flush() - - return v, true - } -} - -// Bytes returns the byte slices on which the queue operates. -func (t *Tx) Bytes() (tx, rx []byte) { - return t.tx.Bytes(), t.rx.Bytes() -} - -// TxPacketInfo holds information about a packet sent on a tx queue. -type TxPacketInfo struct { - ID uint64 - Size uint32 - Reserved uint32 - BufferCount int -} - -// DecodeTxPacketHeader decodes the header of a packet sent over a tx queue. -func DecodeTxPacketHeader(b []byte) TxPacketInfo { - return TxPacketInfo{ - ID: binary.LittleEndian.Uint64(b[packetID:]), - Size: binary.LittleEndian.Uint32(b[packetSize:]), - Reserved: binary.LittleEndian.Uint32(b[packetReserved:]), - BufferCount: (len(b) - sizeOfPacketHeader) / sizeOfBufferDescriptor, - } -} - -// DecodeTxBufferHeader decodes the header of the i-th buffer of a packet sent -// over a tx queue. -func DecodeTxBufferHeader(b []byte, i int) TxBuffer { - b = b[sizeOfPacketHeader+i*sizeOfBufferDescriptor:] - return TxBuffer{ - Offset: binary.LittleEndian.Uint64(b[bufferOffset:]), - Size: binary.LittleEndian.Uint32(b[bufferSize:]), - } -} - -// EncodeTxCompletion encodes a tx completion header. -func EncodeTxCompletion(b []byte, id uint64) { - binary.LittleEndian.PutUint64(b, id) -} diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go deleted file mode 100644 index eec11e4cb..000000000 --- a/pkg/tcpip/link/sharedmem/rx.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package sharedmem - -import ( - "sync/atomic" - "syscall" - - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" -) - -// rx holds all state associated with an rx queue. -type rx struct { - data []byte - sharedData []byte - q queue.Rx - eventFD int -} - -// init initializes all state needed by the rx queue based on the information -// provided. -// -// The caller always retains ownership of all file descriptors passed in. The -// queue implementation will duplicate any that it may need in the future. -func (r *rx) init(mtu uint32, c *QueueConfig) error { - // Map in all buffers. - txPipe, err := getBuffer(c.TxPipeFD) - if err != nil { - return err - } - - rxPipe, err := getBuffer(c.RxPipeFD) - if err != nil { - syscall.Munmap(txPipe) - return err - } - - data, err := getBuffer(c.DataFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - return err - } - - sharedData, err := getBuffer(c.SharedDataFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - syscall.Munmap(data) - return err - } - - // Duplicate the eventFD so that caller can close it but we can still - // use it. - efd, err := syscall.Dup(c.EventFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - syscall.Munmap(data) - syscall.Munmap(sharedData) - return err - } - - // Set the eventfd as non-blocking. - if err := syscall.SetNonblock(efd, true); err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - syscall.Munmap(data) - syscall.Munmap(sharedData) - syscall.Close(efd) - return err - } - - // Initialize state based on buffers. - r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData)) - r.data = data - r.eventFD = efd - r.sharedData = sharedData - - return nil -} - -// cleanup releases all resources allocated during init(). It must only be -// called if init() has previously succeeded. -func (r *rx) cleanup() { - a, b := r.q.Bytes() - syscall.Munmap(a) - syscall.Munmap(b) - - syscall.Munmap(r.data) - syscall.Munmap(r.sharedData) - syscall.Close(r.eventFD) -} - -// postAndReceive posts the provided buffers (if any), and then tries to read -// from the receive queue. -// -// Capacity permitting, it reuses the posted buffer slice to store the buffers -// that were read as well. -// -// This function will block if there aren't any available packets. -func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.RxBuffer, uint32) { - // Post the buffers first. If we cannot post, sleep until we can. We - // never post more than will fit concurrently, so it's safe to wait - // until enough room is available. - if len(b) != 0 && !r.q.PostBuffers(b) { - r.q.EnableNotification() - for !r.q.PostBuffers(b) { - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) - if atomic.LoadUint32(stopRequested) != 0 { - r.q.DisableNotification() - return nil, 0 - } - } - r.q.DisableNotification() - } - - // Read the next set of descriptors. - b, n := r.q.Dequeue(b[:0]) - if len(b) != 0 { - return b, n - } - - // Data isn't immediately available. Enable eventfd notifications. - r.q.EnableNotification() - for { - b, n = r.q.Dequeue(b) - if len(b) != 0 { - break - } - - // Wait for notification. - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) - if atomic.LoadUint32(stopRequested) != 0 { - r.q.DisableNotification() - return nil, 0 - } - } - r.q.DisableNotification() - - return b, n -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go deleted file mode 100644 index 080f9d667..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// Package sharedmem provides the implemention of data-link layer endpoints -// backed by shared memory. -// -// Shared memory endpoints can be used in the networking stack by calling New() -// to create a new endpoint, and then passing it as an argument to -// Stack.CreateNIC(). -package sharedmem - -import ( - "sync" - "sync/atomic" - "syscall" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// QueueConfig holds all the file descriptors needed to describe a tx or rx -// queue over shared memory. It is used when creating new shared memory -// endpoints to describe tx and rx queues. -type QueueConfig struct { - // DataFD is a file descriptor for the file that contains the data to - // be transmitted via this queue. Descriptors contain offsets within - // this file. - DataFD int - - // EventFD is a file descriptor for the event that is signaled when - // data is becomes available in this queue. - EventFD int - - // TxPipeFD is a file descriptor for the tx pipe associated with the - // queue. - TxPipeFD int - - // RxPipeFD is a file descriptor for the rx pipe associated with the - // queue. - RxPipeFD int - - // SharedDataFD is a file descriptor for the file that contains shared - // state between the two ends of the queue. This data specifies, for - // example, whether EventFD signaling is enabled or disabled. - SharedDataFD int -} - -type endpoint struct { - // mtu (maximum transmission unit) is the maximum size of a packet. - mtu uint32 - - // bufferSize is the size of each individual buffer. - bufferSize uint32 - - // addr is the local address of this endpoint. - addr tcpip.LinkAddress - - // rx is the receive queue. - rx rx - - // stopRequested is to be accessed atomically only, and determines if - // the worker goroutines should stop. - stopRequested uint32 - - // Wait group used to indicate that all workers have stopped. - completed sync.WaitGroup - - // mu protects the following fields. - mu sync.Mutex - - // tx is the transmit queue. - tx tx - - // workerStarted specifies whether the worker goroutine was started. - workerStarted bool -} - -// New creates a new shared-memory-based endpoint. Buffers will be broken up -// into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { - e := &endpoint{ - mtu: mtu, - bufferSize: bufferSize, - addr: addr, - } - - if err := e.tx.init(bufferSize, &tx); err != nil { - return nil, err - } - - if err := e.rx.init(bufferSize, &rx); err != nil { - e.tx.cleanup() - return nil, err - } - - return e, nil -} - -// Close frees all resources associated with the endpoint. -func (e *endpoint) Close() { - // Tell dispatch goroutine to stop, then write to the eventfd so that - // it wakes up in case it's sleeping. - atomic.StoreUint32(&e.stopRequested, 1) - syscall.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Cleanup the queues inline if the worker hasn't started yet; we also - // know it won't start from now on because stopRequested is set to 1. - e.mu.Lock() - workerPresent := e.workerStarted - e.mu.Unlock() - - if !workerPresent { - e.tx.cleanup() - e.rx.cleanup() - } -} - -// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have -// stopped after a Close() call. -func (e *endpoint) Wait() { - e.completed.Wait() -} - -// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that -// reads packets from the rx queue. -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.mu.Lock() - if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { - e.workerStarted = true - e.completed.Add(1) - // Link endpoints are not savable. When transportation endpoints - // are saved, they stop sending outgoing packets and all - // incoming packets are rejected. - go e.dispatchLoop(dispatcher) // S/R-SAFE: see above. - } - e.mu.Unlock() -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *endpoint) IsAttached() bool { - e.mu.Lock() - defer e.mu.Unlock() - return e.workerStarted -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *endpoint) MTU() uint32 { - return e.mtu - header.EthernetMinimumSize -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the -// ethernet frame header size. -func (*endpoint) MaxHeaderLength() uint16 { - return header.EthernetMinimumSize -} - -// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local -// link address. -func (e *endpoint) LinkAddress() tcpip.LinkAddress { - return e.addr -} - -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - // Add the ethernet header here. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, - } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - - v := pkt.Data.ToView() - // Transmit the packet. - e.mu.Lock() - ok := e.tx.transmit(pkt.Header.View(), v) - e.mu.Unlock() - - if !ok { - return tcpip.ErrWouldBlock - } - - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - panic("not implemented") -} - -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - v := vv.ToView() - // Transmit the packet. - e.mu.Lock() - ok := e.tx.transmit(v, buffer.View{}) - e.mu.Unlock() - - if !ok { - return tcpip.ErrWouldBlock - } - - return nil -} - -// dispatchLoop reads packets from the rx queue in a loop and dispatches them -// to the network stack. -func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { - // Post initial set of buffers. - limit := e.rx.q.PostedBuffersLimit() - if l := uint64(len(e.rx.data)) / uint64(e.bufferSize); limit > l { - limit = l - } - for i := uint64(0); i < limit; i++ { - b := queue.RxBuffer{ - Offset: i * uint64(e.bufferSize), - Size: e.bufferSize, - ID: i, - } - if !e.rx.q.PostBuffers([]queue.RxBuffer{b}) { - log.Warningf("Unable to post %v-th buffer", i) - } - } - - // Read in a loop until a stop is requested. - var rxb []queue.RxBuffer - for atomic.LoadUint32(&e.stopRequested) == 0 { - var n uint32 - rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested) - - // Copy data from the shared area to its own buffer, then - // prepare to repost the buffer. - b := make([]byte, n) - offset := uint32(0) - for i := range rxb { - copy(b[offset:], e.rx.data[rxb[i].Offset:][:rxb[i].Size]) - offset += rxb[i].Size - - rxb[i].Size = e.bufferSize - } - - if n < header.EthernetMinimumSize { - continue - } - - // Send packet up the stack. - eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ - Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), - LinkHeader: buffer.View(eth), - }) - } - - // Clean state. - e.tx.cleanup() - e.rx.cleanup() - - e.completed.Done() -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go deleted file mode 100644 index 89603c48f..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ /dev/null @@ -1,812 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package sharedmem - -import ( - "bytes" - "io/ioutil" - "math/rand" - "os" - "strings" - "sync" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - localLinkAddr = "\xde\xad\xbe\xef\x56\x78" - remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" - - queueDataSize = 1024 * 1024 - queuePipeSize = 4096 -) - -type queueBuffers struct { - data []byte - rx pipe.Tx - tx pipe.Rx -} - -func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) { - // Prepare tx pipe. - b, err := getBuffer(c.TxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.tx.Init(b) - - // Prepare rx pipe. - b, err = getBuffer(c.RxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.rx.Init(b) - - // Get data slice. - q.data, err = getBuffer(c.DataFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } -} - -func (q *queueBuffers) cleanup() { - syscall.Munmap(q.tx.Bytes()) - syscall.Munmap(q.rx.Bytes()) - syscall.Munmap(q.data) -} - -type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView - linkHeader buffer.View -} - -type testContext struct { - t *testing.T - ep *endpoint - txCfg QueueConfig - rxCfg QueueConfig - txq queueBuffers - rxq queueBuffers - - packetCh chan struct{} - mu sync.Mutex - packets []packetInfo -} - -func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext { - var err error - c := &testContext{ - t: t, - packetCh: make(chan struct{}, 1000000), - } - c.txCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - c.rxCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - initQueue(t, &c.txq, &c.txCfg) - initQueue(t, &c.rxq, &c.rxCfg) - - ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - c.ep = ep.(*endpoint) - c.ep.Attach(c) - - return c -} - -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - c.mu.Lock() - c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - vv: pkt.Data.Clone(nil), - }) - c.mu.Unlock() - - c.packetCh <- struct{}{} -} - -func (c *testContext) cleanup() { - c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) - c.txq.cleanup() - c.rxq.cleanup() -} - -func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) { - for i := 0; i < n; i++ { - select { - case <-c.packetCh: - case <-to: - c.t.Fatalf(errorStr) - } - } -} - -func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) { - b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs))) - queue.EncodeRxCompletion(b, size, 0) - for i := range bs { - queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{ - Offset: bs[i].Offset, - Size: bs[i].Size, - ID: bs[i].ID, - }) - } -} - -func randomFill(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func shuffle(b []int) { - for i := len(b) - 1; i >= 0; i-- { - j := rand.Intn(i + 1) - b[i], b[j] = b[j], b[i] - } -} - -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir := os.Getenv("TEST_TMPDIR") - if tmpDir == "" { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - syscall.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := syscall.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := syscall.Ftruncate(fd, size); err != nil { - syscall.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - syscall.Close(c.DataFD) - syscall.Close(c.EventFD) - syscall.Close(c.TxPipeFD) - syscall.Close(c.RxPipeFD) - syscall.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - -// TestSimpleSend sends 1000 packets with random header and payload sizes, -// then checks that the right payload is received on the shared memory queues. -func TestSimpleSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - for iters := 1000; iters > 0; iters-- { - func() { - // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) - randomFill(hdrBuf) - - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check the ethernet header. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: localLinkAddr, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } - - // Compare contents skipping the ethernet header added by the - // endpoint. - merged := append(hdrBuf, buf...) - if uint32(len(contents)) < pi.Size { - t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) - } - contents = contents[:pi.Size][header.EthernetMinimumSize:] - - if !bytes.Equal(contents, merged) { - t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) - } - }() - } -} - -// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress -// set in Route (using much of the same code as TestSimpleSend), then checks -// that the encoded ethernet header received includes the correct SrcAddr. -func TestPreserveSrcAddressInSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) - // Set both remote and local link address in route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - LocalLinkAddress: newLocalLinkAddress, - } - - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ - Header: hdr, - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check that the ethernet header contains the expected SrcAddr. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: newLocalLinkAddress, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } -} - -// TestFillTxQueue sends packets until the queue is full. -func TestFillTxQueue(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets -// until the queue is full. -func TestFillTxQueueAfterBadCompletion(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Send a bad completion. - queue.EncodeTxCompletion(c.txq.rx.Push(8), 1) - c.txq.rx.Flush() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Send two packets so that the id slice has at least two slots. - for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } - - // Complete the two writes twice. - for i := 2; i > 0; i-- { - pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull()) - - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - c.txq.rx.Flush() - } - c.txq.tx.Flush() - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxMemory sends packets until the we run out of shared memory. -func TestFillTxMemory(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible until - // we fill the memory. - ids := make(map[uint64]struct{}) - for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - c.txq.tx.Flush() - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }) - if want := tcpip.ErrWouldBlock; err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of -// shared memory for a 2-buffer packet, but still with room for a 1-buffer -// packet. -func TestFillTxMemoryWithMultiBuffer(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible - // until there is only one buffer left. - for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Pull the posted buffer. - c.txq.tx.Pull() - c.txq.tx.Flush() - } - - // Attempt to write a two-buffer packet. It must fail. - { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: uu, - }); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } - } - - // Attempt to write the one-buffer packet again. It must succeed. - { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } -} - -func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte { - t.Helper() - - for { - b := p.Pull() - if b != nil { - return b - } - - select { - case <-time.After(10 * time.Millisecond): - case <-to: - t.Fatal(errStr) - } - } -} - -// TestSimpleReceive completes 1000 different receives with random payload and -// random number of buffers. It checks that the contents match the expected -// values. -func TestSimpleReceive(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Check that buffers have been posted. - limit := c.ep.rx.q.PostedBuffersLimit() - for i := uint64(0); i < limit; i++ { - timeout := time.After(2 * time.Second) - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted")) - - if want := i * bufferSize; want != bi.Offset { - t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want) - } - - if want := i; want != bi.ID { - t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want) - } - - if bufferSize != bi.Size { - t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize) - } - } - c.rxq.tx.Flush() - - // Create a slice with the indices 0..limit-1. - idx := make([]int, limit) - for i := range idx { - idx[i] = i - } - - // Complete random packets 1000 times. - for iters := 1000; iters > 0; iters-- { - timeout := time.After(2 * time.Second) - // Prepare a random packet. - shuffle(idx) - n := 1 + rand.Intn(10) - bufs := make([]queue.RxBuffer, n) - contents := make([]byte, bufferSize*n-rand.Intn(500)) - randomFill(contents) - for i := range bufs { - j := idx[i] - bufs[i].Size = bufferSize - bufs[i].Offset = uint64(bufferSize * j) - bufs[i].ID = uint64(j) - - copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:]) - } - - // Push completion. - c.pushRxCompletion(uint32(len(contents)), bufs) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be received, then check it. - c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") - c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) - c.packets = c.packets[:0] - c.mu.Unlock() - - if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) { - t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents) - } - - // Check that buffers have been reposted. - for i := range bufs { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted")) - if bi != bufs[i] { - t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i]) - } - } - c.rxq.tx.Flush() - } -} - -// TestRxBuffersReposted tests that rx buffers get reposted after they have been -// completed. -func TestRxBuffersReposted(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Receive all posted buffers. - limit := c.ep.rx.q.PostedBuffersLimit() - buffers := make([]queue.RxBuffer, 0, limit) - for i := limit; i > 0; i-- { - timeout := time.After(2 * time.Second) - buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers"))) - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when individually completed. - for i := range buffers { - timeout := time.After(2 * time.Second) - // Complete the buffer. - c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for it to be reposted. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[i] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i]) - } - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when completed in pairs. - for i := 0; i < len(buffers)/2; i++ { - timeout := time.After(2 * time.Second) - // Complete with two buffers. - c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for them to be reposted. - for j := 0; j < 2; j++ { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[2*i+j] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j]) - } - } - } - c.rxq.tx.Flush() -} - -// TestReceivePostingIsFull checks that the endpoint will properly handle the -// case when a received buffer cannot be immediately reposted because it hasn't -// been pulled from the tx pipe yet. -func TestReceivePostingIsFull(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Complete first posted buffer before flushing it from the tx pipe. - first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) - c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that packet is received. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Complete another buffer. - second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) - c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that no packet is received yet, as the worker is blocked trying - // to repost. - select { - case <-time.After(500 * time.Millisecond): - case <-c.packetCh: - t.Fatalf("Unexpected packet received") - } - - // Flush tx queue, which will allow the first buffer to be reposted, - // and the second completion to be pulled. - c.rxq.tx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that second packet completes. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") -} - -// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to -// repost a buffer. Make sure it backs out. -func TestCloseWhileWaitingToPost(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - cleaned := false - defer func() { - if !cleaned { - c.cleanup() - } - }() - - // Complete first posted buffer before flushing it from the tx pipe. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) - c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be indicated. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Cleanup and wait for worker to complete. - c.cleanup() - cleaned = true - c.ep.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go deleted file mode 100644 index f7e816a41..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sharedmem - -import ( - "unsafe" -) - -// sharedDataPointer converts the shared data slice into a pointer so that it -// can be used in atomic operations. -func sharedDataPointer(sharedData []byte) *uint32 { - return (*uint32)(unsafe.Pointer(&sharedData[0:4][0])) -} diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go deleted file mode 100644 index 6b8d7859d..000000000 --- a/pkg/tcpip/link/sharedmem/tx.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sharedmem - -import ( - "math" - "syscall" - - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" -) - -const ( - nilID = math.MaxUint64 -) - -// tx holds all state associated with a tx queue. -type tx struct { - data []byte - q queue.Tx - ids idManager - bufs bufferManager -} - -// init initializes all state needed by the tx queue based on the information -// provided. -// -// The caller always retains ownership of all file descriptors passed in. The -// queue implementation will duplicate any that it may need in the future. -func (t *tx) init(mtu uint32, c *QueueConfig) error { - // Map in all buffers. - txPipe, err := getBuffer(c.TxPipeFD) - if err != nil { - return err - } - - rxPipe, err := getBuffer(c.RxPipeFD) - if err != nil { - syscall.Munmap(txPipe) - return err - } - - data, err := getBuffer(c.DataFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - return err - } - - // Initialize state based on buffers. - t.q.Init(txPipe, rxPipe) - t.ids.init() - t.bufs.init(0, len(data), int(mtu)) - t.data = data - - return nil -} - -// cleanup releases all resources allocated during init(). It must only be -// called if init() has previously succeeded. -func (t *tx) cleanup() { - a, b := t.q.Bytes() - syscall.Munmap(a) - syscall.Munmap(b) - syscall.Munmap(t.data) -} - -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { - // Pull completions from the tx queue and add their buffers back to the - // pool so that we can reuse them. - for { - id, ok := t.q.CompletedPacket() - if !ok { - break - } - - if buf := t.ids.remove(id); buf != nil { - t.bufs.free(buf) - } - } - - bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) - bufCount := (total + bSize - 1) / bSize - - // Allocate enough buffers to hold all the data. - var buf *queue.TxBuffer - for i := bufCount; i != 0; i-- { - b := t.bufs.alloc() - if b == nil { - // Failed to get all buffers. Return to the pool - // whatever we had managed to get. - if buf != nil { - t.bufs.free(buf) - } - return false - } - b.Next = buf - buf = b - } - - // Copy data into allocated buffers. - nBuf := buf - var dBuf []byte - for _, data := range [][]byte{a, b} { - for len(data) > 0 { - if len(dBuf) == 0 { - dBuf = t.data[nBuf.Offset:][:nBuf.Size] - nBuf = nBuf.Next - } - n := copy(dBuf, data) - data = data[n:] - dBuf = dBuf[n:] - } - } - - // Get an id for this packet and send it out. - id := t.ids.add(buf) - if !t.q.Enqueue(id, total, bufCount, buf) { - t.ids.remove(id) - t.bufs.free(buf) - return false - } - - return true -} - -// getBuffer returns a memory region mapped to the full contents of the given -// file descriptor. -func getBuffer(fd int) ([]byte, error) { - var s syscall.Stat_t - if err := syscall.Fstat(fd, &s); err != nil { - return nil, err - } - - // Check that size doesn't overflow an int. - if s.Size > int64(^uint(0)>>1) { - return nil, syscall.EDOM - } - - return syscall.Mmap(fd, 0, int(s.Size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED|syscall.MAP_FILE) -} - -// idDescriptor is used by idManager to either point to a tx buffer (in case -// the ID is assigned) or to the next free element (if the id is not assigned). -type idDescriptor struct { - buf *queue.TxBuffer - nextFree uint64 -} - -// idManager is a manager of tx buffer identifiers. It assigns unique IDs to -// tx buffers that are added to it; the IDs can only be reused after they have -// been removed. -// -// The ID assignments are stored so that the tx buffers can be retrieved from -// the IDs previously assigned to them. -type idManager struct { - // ids is a slice containing all tx buffers. The ID is the index into - // this slice. - ids []idDescriptor - - // freeList a list of free IDs. - freeList uint64 -} - -// init initializes the id manager. -func (m *idManager) init() { - m.freeList = nilID -} - -// add assigns an ID to the given tx buffer. -func (m *idManager) add(b *queue.TxBuffer) uint64 { - if i := m.freeList; i != nilID { - // There is an id available in the free list, just use it. - m.ids[i].buf = b - m.freeList = m.ids[i].nextFree - return i - } - - // We need to expand the id descriptor. - m.ids = append(m.ids, idDescriptor{buf: b}) - return uint64(len(m.ids) - 1) -} - -// remove retrieves the tx buffer associated with the given ID, and removes the -// ID from the assigned table so that it can be reused in the future. -func (m *idManager) remove(i uint64) *queue.TxBuffer { - if i >= uint64(len(m.ids)) { - return nil - } - - desc := &m.ids[i] - b := desc.buf - if b == nil { - // The provided id is not currently assigned. - return nil - } - - desc.buf = nil - desc.nextFree = m.freeList - m.freeList = i - - return b -} - -// bufferManager manages a buffer region broken up into smaller, equally sized -// buffers. Smaller buffers can be allocated and freed. -type bufferManager struct { - freeList *queue.TxBuffer - curOffset uint64 - limit uint64 - entrySize uint32 -} - -// init initializes the buffer manager. -func (b *bufferManager) init(initialOffset, size, entrySize int) { - b.freeList = nil - b.curOffset = uint64(initialOffset) - b.limit = uint64(initialOffset + size/entrySize*entrySize) - b.entrySize = uint32(entrySize) -} - -// alloc allocates a buffer from the manager, if one is available. -func (b *bufferManager) alloc() *queue.TxBuffer { - if b.freeList != nil { - // There is a descriptor ready for reuse in the free list. - d := b.freeList - b.freeList = d.Next - d.Next = nil - return d - } - - if b.curOffset < b.limit { - // There is room available in the never-used range, so create - // a new descriptor for it. - d := &queue.TxBuffer{ - Offset: b.curOffset, - Size: b.entrySize, - } - b.curOffset += uint64(b.entrySize) - return d - } - - return nil -} - -// free returns all buffers in the list to the buffer manager so that they can -// be reused. -func (b *bufferManager) free(d *queue.TxBuffer) { - // Find the last buffer in the list. - last := d - for last.Next != nil { - last = last.Next - } - - // Push list onto free list. - last.Next = b.freeList - b.freeList = d -} diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD deleted file mode 100644 index d6ae0368a..000000000 --- a/pkg/tcpip/link/sniffer/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sniffer", - srcs = [ - "pcap.go", - "sniffer.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go new file mode 100755 index 000000000..cfd84a739 --- /dev/null +++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package sniffer + diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD deleted file mode 100644 index a71a493fc..000000000 --- a/pkg/tcpip/link/tun/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "tun", - srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", - visibility = ["//visibility:public"], -) diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go deleted file mode 100644 index 09ca9b527..000000000 --- a/pkg/tcpip/link/tun/tun_unsafe.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// Package tun contains methods to open TAP and TUN devices. -package tun - -import ( - "syscall" - "unsafe" -) - -// Open opens the specified TUN device, sets it to non-blocking mode, and -// returns its file descriptor. -func Open(name string) (int, error) { - return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI) -} - -// OpenTAP opens the specified TAP device, sets it to non-blocking mode, and -// returns its file descriptor. -func OpenTAP(name string) (int, error) { - return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI) -} - -func open(name string, flags uint16) (int, error) { - fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0) - if err != nil { - return -1, err - } - - var ifr struct { - name [16]byte - flags uint16 - _ [22]byte - } - - copy(ifr.name[:], name) - ifr.flags = flags - _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) - if errno != 0 { - syscall.Close(fd) - return -1, errno - } - - if err = syscall.SetNonblock(fd, true); err != nil { - syscall.Close(fd) - return -1, err - } - - return fd, nil -} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD deleted file mode 100644 index 134837943..000000000 --- a/pkg/tcpip/link/waitable/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "waitable", - srcs = [ - "waitable.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", - visibility = ["//visibility:public"], - deps = [ - "//pkg/gate", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "waitable_test", - srcs = [ - "waitable_test.go", - ], - embed = [":waitable"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go deleted file mode 100644 index a8de38979..000000000 --- a/pkg/tcpip/link/waitable/waitable.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package waitable provides the implementation of data-link layer endpoints -// that wrap other endpoints, and can wait for inflight calls to WritePacket or -// DeliverNetworkPacket to finish (and new ones to be prevented). -// -// Waitable endpoints can be used in the networking stack by calling New(eID) to -// create a new endpoint, where eID is the ID of the endpoint being wrapped, -// and then passing it as an argument to Stack.CreateNIC(). -package waitable - -import ( - "gvisor.dev/gvisor/pkg/gate" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// Endpoint is a waitable link-layer endpoint. -type Endpoint struct { - dispatchGate gate.Gate - dispatcher stack.NetworkDispatcher - - writeGate gate.Gate - lower stack.LinkEndpoint -} - -// New creates a new waitable link-layer endpoint. It wraps around another -// endpoint and allows the caller to block new write/dispatch calls and wait for -// the inflight ones to finish before returning. -func New(lower stack.LinkEndpoint) *Endpoint { - return &Endpoint{ - lower: lower, - } -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -// It is called by the link-layer endpoint being wrapped when a packet arrives, -// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't -// been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - if !e.dispatchGate.Enter() { - return - } - - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) - e.dispatchGate.Leave() -} - -// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and -// registers with the lower endpoint as its dispatcher so that "e" is called -// for inbound packets. -func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - e.lower.Attach(e) -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *Endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the -// lower endpoint. -func (e *Endpoint) MTU() uint32 { - return e.lower.MTU() -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the -// request to the lower endpoint. -func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.lower.Capabilities() -} - -// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just -// forwards the request to the lower endpoint. -func (e *Endpoint) MaxHeaderLength() uint16 { - return e.lower.MaxHeaderLength() -} - -// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the -// request to the lower endpoint. -func (e *Endpoint) LinkAddress() tcpip.LinkAddress { - return e.lower.LinkAddress() -} - -// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by -// higher-level protocols to write packets. It only forwards packets to the -// lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - if !e.writeGate.Enter() { - return nil - } - - err := e.lower.WritePacket(r, gso, protocol, pkt) - e.writeGate.Leave() - return err -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by -// higher-level protocols to write packets. It only forwards packets to the -// lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - if !e.writeGate.Enter() { - return len(pkts), nil - } - - n, err := e.lower.WritePackets(r, gso, pkts, protocol) - e.writeGate.Leave() - return n, err -} - -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if !e.writeGate.Enter() { - return nil - } - - err := e.lower.WriteRawPacket(vv) - e.writeGate.Leave() - return err -} - -// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, -// and waits for inflight ones to finish before returning. -func (e *Endpoint) WaitWrite() { - e.writeGate.Close() -} - -// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the -// actual dispatcher, and waits for inflight ones to finish before returning. -func (e *Endpoint) WaitDispatch() { - e.dispatchGate.Close() -} - -// Wait implements stack.LinkEndpoint.Wait. -func (e *Endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go deleted file mode 100644 index 31b11a27a..000000000 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waitable - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type countedEndpoint struct { - dispatchCount int - writeCount int - attachCount int - - mtu uint32 - capabilities stack.LinkEndpointCapabilities - hdrLen uint16 - linkAddr tcpip.LinkAddress - - dispatcher stack.NetworkDispatcher -} - -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - e.dispatchCount++ -} - -func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.attachCount++ - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *countedEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *countedEndpoint) MTU() uint32 { - return e.mtu -} - -func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities -} - -func (e *countedEndpoint) MaxHeaderLength() uint16 { - return e.hdrLen -} - -func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - e.writeCount++ - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(pkts) - return len(pkts), nil -} - -func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - e.writeCount++ - return nil -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*countedEndpoint) Wait() {} - -func TestWaitWrite(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) - if want := 1; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on dispatches, then try to write. It must go through. - wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on writes, then try to write. It must not go through. - wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } -} - -func TestWaitDispatch(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Check that attach happens. - wep.Attach(ep) - if want := 1; ep.attachCount != want { - t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want) - } - - // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) - if want := 1; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on writes, then try to dispatch. It must go through. - wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on dispatches, then try to dispatch. It must not go through. - wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } -} - -func TestOtherMethods(t *testing.T) { - const ( - mtu = 0xdead - capabilities = 0xbeef - hdrLen = 0x1234 - linkAddr = "test address" - ) - ep := &countedEndpoint{ - mtu: mtu, - capabilities: capabilities, - hdrLen: hdrLen, - linkAddr: linkAddr, - } - wep := New(ep) - - if v := wep.MTU(); v != mtu { - t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) - } - - if v := wep.Capabilities(); v != capabilities { - t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities) - } - - if v := wep.MaxHeaderLength(); v != hdrLen { - t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen) - } - - if v := wep.LinkAddress(); v != linkAddr { - t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr) - } -} diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD deleted file mode 100644 index 9d16ff8c9..000000000 --- a/pkg/tcpip/network/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "ip_test", - size = "small", - srcs = [ - "ip_test.go", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - ], -) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD deleted file mode 100644 index e7617229b..000000000 --- a/pkg/tcpip/network/arp/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "arp", - srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "arp_test", - size = "small", - srcs = ["arp_test.go"], - deps = [ - ":arp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - ], -) diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go new file mode 100755 index 000000000..14a21baff --- /dev/null +++ b/pkg/tcpip/network/arp/arp_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package arp + diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go deleted file mode 100644 index 8e6048a21..000000000 --- a/pkg/tcpip/network/arp/arp_test.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp_test - -import ( - "strconv" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -const ( - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") -) - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack -} - -func newTestContext(t *testing.T) *testContext { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, - }) - - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("AddAddress for arp failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - close(c.linkEP.C) -} - -func TestDirectRequest(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - - inject := func(addr tcpip.Address) { - copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ - Data: v.ToVectorisedView(), - }) - } - - for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { - t.Run(strconv.Itoa(i), func(t *testing.T) { - inject(address) - pi := <-c.linkEP.C - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.Header.View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) - } - }) - } - - inject(stackAddrBad) - select { - case pkt := <-c.linkEP.C: - t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. - } -} diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD deleted file mode 100644 index acf1e022c..000000000 --- a/pkg/tcpip/network/fragmentation/BUILD +++ /dev/null @@ -1,46 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "frag_heap.go", - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "frag_heap_test.go", - "fragmentation_test.go", - "reassembler_test.go", - ], - embed = [":fragmentation"], - deps = ["//pkg/tcpip/buffer"], -) diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go new file mode 100755 index 000000000..6b71d27c7 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package fragmentation + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *reassemblerList) beforeSave() {} +func (x *reassemblerList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *reassemblerList) afterLoad() {} +func (x *reassemblerList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *reassemblerEntry) beforeSave() {} +func (x *reassemblerEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *reassemblerEntry) afterLoad() {} +func (x *reassemblerEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("fragmentation.reassemblerList", (*reassemblerList)(nil), state.Fns{Save: (*reassemblerList).save, Load: (*reassemblerList).load}) + state.Register("fragmentation.reassemblerEntry", (*reassemblerEntry)(nil), state.Fns{Save: (*reassemblerEntry).save, Load: (*reassemblerEntry).load}) +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go deleted file mode 100644 index 72c0f53be..000000000 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -type processInput struct { - id uint32 - first uint16 - last uint16 - more bool - vv buffer.VectorisedView -} - -type processOutput struct { - vv buffer.VectorisedView - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) - for i, in := range c.in { - vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) - if err != nil { - t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) - } - if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) - } - if done != c.out[i].done { - t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) - } - if c.out[i].done { - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - timeout := time.Millisecond - f := NewFragmentation(1024, 512, timeout) - // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(0, 0, 0, true, vv(1, "0")) - // Sleep more than the timeout. - time.Sleep(2 * timeout) - // Send another fragment that completes a packet. - // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) - if err != nil { - t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) - } - if done { - t.Errorf("Fragmentation does not respect the reassembling timeout.") - } -} - -func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(3, 1, DefaultReassembleTimeout) - // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) - // Send first fragment with id = 1. - f.Process(1, 0, 0, true, vv(1, "1")) - // Send first fragment with id = 2. - f.Process(2, 0, 0, true, vv(1, "2")) - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. - f.Process(3, 0, 0, true, vv(1, "3")) - - if _, ok := f.reassemblers[0]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[1]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[3]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(1, 0, DefaultReassembleTimeout) - // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) - // Send the same packet again. - f.Process(0, 0, 0, true, vv(1, "0")) - - got := f.size - want := 1 - if got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler_list.go b/pkg/tcpip/network/fragmentation/reassembler_list.go new file mode 100755 index 000000000..3189cae29 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/reassembler_list.go @@ -0,0 +1,173 @@ +package fragmentation + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type reassemblerElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type reassemblerList struct { + head *reassembler + tail *reassembler +} + +// Reset resets list l to the empty state. +func (l *reassemblerList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *reassemblerList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *reassemblerList) Front() *reassembler { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *reassemblerList) Back() *reassembler { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *reassemblerList) PushFront(e *reassembler) { + reassemblerElementMapper{}.linkerFor(e).SetNext(l.head) + reassemblerElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *reassemblerList) PushBack(e *reassembler) { + reassemblerElementMapper{}.linkerFor(e).SetNext(nil) + reassemblerElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *reassemblerList) PushBackList(m *reassemblerList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) + reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *reassemblerList) InsertAfter(b, e *reassembler) { + a := reassemblerElementMapper{}.linkerFor(b).Next() + reassemblerElementMapper{}.linkerFor(e).SetNext(a) + reassemblerElementMapper{}.linkerFor(e).SetPrev(b) + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *reassemblerList) InsertBefore(a, e *reassembler) { + b := reassemblerElementMapper{}.linkerFor(a).Prev() + reassemblerElementMapper{}.linkerFor(e).SetNext(a) + reassemblerElementMapper{}.linkerFor(e).SetPrev(b) + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *reassemblerList) Remove(e *reassembler) { + prev := reassemblerElementMapper{}.linkerFor(e).Prev() + next := reassemblerElementMapper{}.linkerFor(e).Next() + + if prev != nil { + reassemblerElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type reassemblerEntry struct { + next *reassembler + prev *reassembler +} + +// Next returns the entry that follows e in the list. +func (e *reassemblerEntry) Next() *reassembler { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *reassemblerEntry) Prev() *reassembler { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *reassemblerEntry) SetNext(elem *reassembler) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go deleted file mode 100644 index 7eee0710d..000000000 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "math" - "reflect" - "testing" -) - -type updateHolesInput struct { - first uint16 - last uint16 - more bool -} - -var holesTestCases = []struct { - comment string - in []updateHolesInput - want []hole -}{ - { - comment: "No fragments. Expected holes: {[0 -> inf]}.", - in: []updateHolesInput{}, - want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, - }, - { - comment: "One fragment at beginning. Expected holes: {[2, inf]}.", - in: []updateHolesInput{{first: 0, last: 1, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: false}, - }, - }, - { - comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", - in: []updateHolesInput{{first: 1, last: 2, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - {first: 3, last: math.MaxUint16, deleted: false}, - }, - }, - { - comment: "One fragment at the end. Expected holes: {[0, 0]}.", - in: []updateHolesInput{{first: 1, last: 2, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - }, - }, - { - comment: "One fragment completing a packet. Expected holes: {}.", - in: []updateHolesInput{{first: 0, last: 1, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - }, - }, - { - comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 1, more: true}, - {first: 2, last: 3, more: false}, - }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: true}, - }, - }, - { - comment: "Two overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 2, more: true}, - {first: 2, last: 3, more: false}, - }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 3, last: math.MaxUint16, deleted: true}, - }, - }, -} - -func TestUpdateHoles(t *testing.T) { - for _, c := range holesTestCases { - r := newReassembler(0) - for _, i := range c.in { - r.updateHoles(i.first, i.last, i.more) - } - if !reflect.DeepEqual(r.holes, c.want) { - t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) - } - } -} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD deleted file mode 100644 index e6db5c0b0..000000000 --- a/pkg/tcpip/network/hash/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hash", - srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", - visibility = ["//visibility:public"], - deps = [ - "//pkg/rand", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go new file mode 100755 index 000000000..a3bcd4b69 --- /dev/null +++ b/pkg/tcpip/network/hash/hash_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package hash + diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go deleted file mode 100644 index 4144a7837..000000000 --- a/pkg/tcpip/network/ip_test.go +++ /dev/null @@ -1,648 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - localIpv4Addr = "\x0a\x00\x00\x01" - localIpv4PrefixLen = 24 - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - localIpv6PrefixLen = 120 - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" -) - -// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. -// The former is used to pretend that it's a link endpoint so that we can -// inspect packets written by the network endpoints. The latter is used to -// pretend that it's the network stack so that it can inspect incoming packets -// that have been handled by the network endpoints. -// -// Packets are checked by comparing their fields/values against the expected -// values stored in the test object itself. -type testObject struct { - t *testing.T - protocol tcpip.TransportProtocolNumber - contents []byte - srcAddr tcpip.Address - dstAddr tcpip.Address - v4 bool - typ stack.ControlType - extra uint32 - - dataCalls int - controlCalls int -} - -// checkValues verifies that the transport protocol, data contents, src & dst -// addresses of a packet match what's expected. If any field doesn't match, the -// test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { - v := vv.ToView() - if protocol != t.protocol { - t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) - } - - if srcAddr != t.srcAddr { - t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) - } - - if dstAddr != t.dstAddr { - t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) - } - - if len(v) != len(t.contents) { - t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) - } - - for i := range t.contents { - if t.contents[i] != v[i] { - t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) - } - } -} - -// DeliverTransportPacket is called by network endpoints after parsing incoming -// packets. This is used by the test object to verify that the results of the -// parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { - t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) - t.dataCalls++ -} - -// DeliverTransportControlPacket is called by network endpoints after parsing -// incoming control (ICMP) packets. This is used by the test object to verify -// that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { - t.checkValues(trans, pkt.Data, remote, local) - if typ != t.typ { - t.t.Errorf("typ = %v, want %v", typ, t.typ) - } - if extra != t.extra { - t.t.Errorf("extra = %v, want %v", extra, t.extra) - } - t.controlCalls++ -} - -// Attach is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (*testObject) IsAttached() bool { - return true -} - -// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that -// matches the linux loopback MTU. -func (*testObject) MTU() uint32 { - return 65536 -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*testObject) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (*testObject) LinkAddress() tcpip.LinkAddress { - return "" -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*testObject) Wait() {} - -// WritePacket is called by network endpoints after producing a packet and -// writing it to the link endpoint. This is used by the test object to verify -// that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - var prot tcpip.TransportProtocolNumber - var srcAddr tcpip.Address - var dstAddr tcpip.Address - - if t.v4 { - h := header.IPv4(pkt.Header.View()) - prot = tcpip.TransportProtocolNumber(h.Protocol()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - - } else { - h := header.IPv6(pkt.Header.View()) - prot = tcpip.TransportProtocolNumber(h.NextHeader()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - } - t.checkValues(prot, pkt.Data, srcAddr, dstAddr) - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - panic("not implemented") -} - -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - -func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, - }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv4.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - Gateway: ipv4Gateway, - NIC: 1, - }}) - - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) -} - -func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, - }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv6.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: ipv6Gateway, - NIC: 1, - }}) - - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) -} - -func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) - - // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: view.ToVectorisedView(), - }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv4ReceiveControl(t *testing.T) { - const mtu = 0xbeef - header.IPv4MinimumSize - cases := []struct { - name string - expectedCount int - fragmentOffset uint16 - code uint8 - expectedTyp stack.ControlType - expectedExtra uint32 - trunc int - }{ - {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, - {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, - {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, - } - r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + 8) - - // Create the outer IPv4 header. - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(view) - c.trunc), - TTL: 20, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIpv4Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) - icmp.SetType(header.ICMPv4DstUnreachable) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: 100, - TTL: 20, - Protocol: 10, - FragmentOffset: c.fragmentOffset, - SrcAddr: localIpv4Addr, - DstAddr: remoteIpv4Addr, - }) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv4 endpoint, dispatcher will validate that - // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra - - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: vv, - }) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) - } - }) - } -} - -func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv4MinimumSize + 24 - - frag1 := buffer.NewView(totalLen) - ip1 := header.IPv4(frag1) - ip1.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 0, - Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag1[i] = uint8(i) - } - - frag2 := buffer.NewView(totalLen) - ip2 := header.IPv4(frag2) - ip2.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 24, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag2[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - // Send first segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: frag1.ToVectorisedView(), - }) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) - } - - // Send second segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: frag2.ToVectorisedView(), - }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv6Send(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) - - // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload - - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIpv6Addr, - DstAddr: localIpv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: view.ToVectorisedView(), - }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv6ReceiveControl(t *testing.T) { - newUint16 := func(v uint16) *uint16 { return &v } - - const mtu = 0xffff - const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" - cases := []struct { - name string - expectedCount int - fragmentOffset *uint16 - typ header.ICMPv6Type - code uint8 - expectedTyp stack.ControlType - expectedExtra uint32 - trunc int - }{ - {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, - {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, - {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, - {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, - } - r, err := buildIPv6Route( - localIpv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - defer ep.Close() - - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize - if c.fragmentOffset != nil { - dataOffset += header.IPv6FragmentHeaderSize - } - view := buffer.NewView(dataOffset + 8) - - // Create the outer IPv6 header. - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIpv6Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) - icmp.SetType(c.typ) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIpv6Addr, - DstAddr: remoteIpv6Addr, - }) - - // Build the fragmentation header if needed. - if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, - FragmentOffset: *c.fragmentOffset, - M: true, - Identification: 0x12345678, - }) - } - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv6 endpoint, dispatcher will validate that - // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra - - // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: view[:len(view)-c.trunc].ToVectorisedView(), - }) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD deleted file mode 100644 index aeddfcdd4..000000000 --- a/pkg/tcpip/network/ipv4/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv4", - srcs = [ - "icmp.go", - "ipv4.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/network/fragmentation", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv4_test", - size = "small", - srcs = ["ipv4_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go new file mode 100755 index 000000000..6b2cc0142 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package ipv4 + diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go deleted file mode 100644 index e900f1b45..000000000 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ /dev/null @@ -1,475 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "bytes" - "encoding/hex" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestExcludeBroadcast(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - }) - - const defaultMTU = 65536 - ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) - if testing.Verbose() { - ep = sniffer.New(ep) - } - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} - - var wq waiter.Queue - t.Run("WithoutPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Cannot connect using a broadcast address as the source. - if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) - } - - // However, we can bind to a broadcast address to listen. - if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { - t.Errorf("Bind failed: %v", err) - } - }) - - t.Run("WithPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := ep.Connect(randomAddr); err != nil { - t.Errorf("Connect failed: %v", err) - } - }) -} - -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s - } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload -} - -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) - - // Make a copy of the IP header, which will be modified in some fields to make - // an expected header. - sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) - sourceCopy.SetChecksum(0) - sourceCopy.SetFlagsFragmentOffset(0, 0) - sourceCopy.SetTotalLength(0) - var offset uint16 - // Build up an array of the bytes sent. - var reassembledPayload []byte - for i, packet := range packets { - // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) - } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) - } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) - } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) - } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) - } - if i < len(packets)-1 { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) - } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) - } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) - // Clear out the checksum and length from the ip because we can't compare - // it. - sourceCopy.SetTotalLength(uint16(len(ip))) - sourceCopy.SetChecksum(0) - sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) - } - } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) - } -} - -type errorChannel struct { - *channel.Endpoint - Ch chan tcpip.PacketBuffer - packetCollectorErrors []*tcpip.Error -} - -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { - return &errorChannel{ - Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan tcpip.PacketBuffer, size), - packetCollectorErrors: packetCollectorErrors, - } -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } -} - -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { - select { - case e.Ch <- pkt: - default: - } - - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - } - return nextError -} - -type context struct { - stack.Route - linkEP *errorChannel -} - -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - s.CreateNIC(1, ep) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - return context{ - Route: r, - linkEP: ep, - } -} - -func TestFragmentation(t *testing.T) { - var manyPayloadViewsSizes [1000]int - for i := range manyPayloadViewsSizes { - manyPayloadViewsSizes[i] = 7 - } - fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int - }{ - {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, - {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, - } - - for _, ft := range fragTests { - t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := tcpip.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - } - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload, - }) - if err != nil { - t.Errorf("err got %v, want %v", err, nil) - } - - var results []tcpip.PacketBuffer - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } - } - - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) - } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) - } - compareFragments(t, results, source, ft.mtu) - }) - } -} - -// TestFragmentationErrors checks that errors are returned from write packet -// correctly. -func TestFragmentationErrors(t *testing.T) { - fragTests := []struct { - description string - mtu uint32 - hdrLength int - payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error - }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, - } - - for _, ft := range fragTests { - t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload, - }) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } - } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) - } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) - } - }) - } -} - -func TestInvalidFragments(t *testing.T) { - // These packets have both IHL and TotalLength set to 0. - testCases := []struct { - name string - packets [][]byte - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - }{ - { - "ihl_totallen_zero_valid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 0, - }, - { - "ihl_totallen_zero_invalid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 0, - }, - { - // Total Length of 37(20 bytes IP header + 17 bytes of - // payload) - // Frag Offset of 0x1ffe = 8190*8 = 65520 - // Leading to the fragment end to be past 65535. - "ihl_totallen_valid_invalid_frag_offset_1", - [][]byte{ - {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 1, - }, - // The following 3 tests were found by running a fuzzer and were - // triggering a panic in the IPv4 reassembler code. - { - "ihl_less_than_ipv4_minimum_size_1", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 2, - 0, - }, - { - "ihl_less_than_ipv4_minimum_size_2", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 2, - 0, - }, - { - "ihl_less_than_ipv4_minimum_size_3", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 2, - 0, - }, - { - "fragment_with_short_total_len_extra_payload", - [][]byte{ - {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 1, - }, - { - "multiple_fragments_with_more_fragments_set_to_false", - [][]byte{ - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - 1, - 1, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - const nicID tcpip.NICID = 42 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - }, - }) - - var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) - var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) - ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicID, sniffer.New(ep)) - - for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), - }) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { - t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { - t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD deleted file mode 100644 index e4e273460..000000000 --- a/pkg/tcpip/network/ipv6/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv6", - srcs = [ - "icmp.go", - "ipv6.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv6_test", - size = "small", - srcs = [ - "icmp_test.go", - "ipv6_test.go", - "ndp_test.go", - ], - embed = [":ipv6"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go deleted file mode 100644 index 335f634d5..000000000 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ /dev/null @@ -1,899 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") -) - -var ( - lladdr0 = header.LinkLocalAddr(linkAddr0) - lladdr1 = header.LinkLocalAddr(linkAddr1) -) - -type stubLinkEndpoint struct { - stack.LinkEndpoint -} - -func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -func (*stubLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { - return nil -} - -func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -type stubDispatcher struct { - stack.TransportDispatcher -} - -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { -} - -type stubLinkAddressCache struct { - stack.LinkAddressCache -} - -func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID { - return 0 -} - -func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { -} - -func TestICMPCounts(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }) - { - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - defer r.Release() - - types := []struct { - typ header.ICMPv6Type - size int - }{ - {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize}, - {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize}, - {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize}, - {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize}, - {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, - {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, - {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, - {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, - } - - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - ep.HandlePacket(&r, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - for _, typ := range types { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - handleIPv6Payload(hdr) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) - - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } -} - -func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { - t := v.Type() - for i := 0; i < v.NumField(); i++ { - v := v.Field(i) - if s, ok := v.Interface().(*tcpip.StatCounter); ok { - f(t.Field(i).Name, s) - } else { - visitStats(v, f) - } - } -} - -type testContext struct { - s0 *stack.Stack - s1 *stack.Stack - - linkEP0 *channel.Endpoint - linkEP1 *channel.Endpoint -} - -type endpointWithResolutionCapability struct { - stack.LinkEndpoint -} - -func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities { - return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired -} - -func newTestContext(t *testing.T) *testContext { - c := &testContext{ - s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }), - s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }), - } - - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) - - wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) - if testing.Verbose() { - wrappedEP0 = sniffer.New(wrappedEP0) - } - if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { - t.Fatalf("CreateNIC s0: %v", err) - } - if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) - } - - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) - wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) - } - - subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - c.s0.SetRouteTable( - []tcpip.Route{{ - Destination: subnet0, - NIC: 1, - }}, - ) - subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) - if err != nil { - t.Fatal(err) - } - c.s1.SetRouteTable( - []tcpip.Route{{ - Destination: subnet1, - NIC: 1, - }}, - ) - - return c -} - -func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) -} - -type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type -} - -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { - t.Helper() - - pi := <-args.src.C - - { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ - Data: vv, - }) - } - - if pi.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pi.Proto) - return - } - ipv6 := header.IPv6(pi.Pkt.Header.View()) - transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) - if transProto != header.ICMPv6ProtocolNumber { - t.Errorf("unexpected transport protocol number %d", transProto) - return - } - icmpv6 := header.ICMPv6(ipv6.Payload()) - if got, want := icmpv6.Type(), args.typ; got != want { - t.Errorf("got ICMPv6 type = %d, want = %d", got, want) - return - } - if fn != nil { - fn(t, icmpv6) - } -} - -func TestLinkResolution(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - defer r.Release() - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - payload := tcpip.SlicePayload(hdr.View()) - - // We can't send our payload directly over the route because that - // doesn't provoke NDP discovery. - var wq waiter.Queue - ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) - if resCh != nil { - if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) - } - <-resCh - continue - } - if err != nil { - t.Fatalf("ep.Write(_) = _, _, %s", err) - } - break - } - - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, - } { - routeICMPv6Packet(t, args, nil) - } -} - -func TestICMPChecksumValidationSimple(t *testing.T) { - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - }, - { - "RouterSolicit", - header.ICMPv6RouterSolicit, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - }, - { - "RouterAdvert", - header.ICMPv6RouterAdvert, - header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - "NeighborSolicit", - header.ICMPv6NeighborSolicit, - header.ICMPv6NeighborSolicitMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - "NeighborAdvert", - header.ICMPv6NeighborAdvert, - header.ICMPv6NeighborAdvertSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - "RedirectMsg", - header.ICMPv6RedirectMsg, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) - if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestICMPChecksumValidationWithPayload(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - icmpSize := size + payloadSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) - - if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) - - payload := buffer.NewView(payloadSize) - payloadFn(payload) - - if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go new file mode 100755 index 000000000..53319e0c4 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package ipv6 + diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go deleted file mode 100644 index 1cbfa7278..000000000 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // The least significant 3 bytes are the same as addr2 so both addr2 and - // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" -) - -// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the -// expected Neighbor Advertisement received count after receiving the packet. -func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - stats := s.Stats().ICMP.V6PacketsReceived - - if got := stats.NeighborAdvert.Value(); got != want { - t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) - } -} - -// testReceiveUDP tests receiving a UDP packet from src to dst. want is the -// expected UDP received count after receiving the packet. -func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - - // Receive UDP Packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - - // UDP pseudo-header checksum. - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) - - // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - stat := s.Stats().UDP.PacketsReceived - - if got := stat.Value(); got != want { - t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) - } -} - -// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and -// UDP packets destined to the IPv6 link-local all-nodes multicast address. -func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocol - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should receive a packet destined to the all-nodes - // multicast address. - test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) - }) - } -} - -// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP -// packets destined to the IPv6 solicited-node address of an assigned IPv6 -// address. -func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocol - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, - } - - snmc := header.SolicitedNodeAddr(addr2) - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as we haven't added - // those addresses. - test.rxf(t, s, e, addr1, snmc, 0) - - if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) - } - - // Should receive a packet destined to the solicited - // node address of addr2/addr3 now that we have added - // added addr2. - test.rxf(t, s, e, addr1, snmc, 1) - - if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) - } - - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have added addr3. - test.rxf(t, s, e, addr1, snmc, 2) - - if err := s.RemoveAddress(1, addr2); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) - } - - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have removed addr2. - test.rxf(t, s, e, addr1, snmc, 3) - - if err := s.RemoveAddress(1, addr3); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) - } - - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as both of them got - // removed. - test.rxf(t, s, e, addr1, snmc, 3) - }) - } -} - -// TestAddIpv6Address tests adding IPv6 addresses. -func TestAddIpv6Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - }{ - // This test is in response to b/140943433. - { - "Nil", - tcpip.Address([]byte(nil)), - }, - { - "ValidUnicast", - addr1, - }, - { - "ValidLinkLocalUnicast", - lladdr0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) - } - - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go deleted file mode 100644 index 0dbce14a0..000000000 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - return s, ep -} - -// TestHopLimitValidation is a test that makes sure that NDP packets are only -// received if their IP header's hop limit is set to 255. -func TestHopLimitValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r - } - - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - ep.HandlePacket(r, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - }{ - {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }}, - {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }}, - {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }}, - {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }}, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - // Should not have received any ICMPv6 packets with - // type = typ.typ. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Receive the NDP packet with an invalid hop limit - // value. - handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) - - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - // Rx count of NDP packet of type typ.typ should not - // have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) - - // Rx count of NDP packet of type typ.typ should have - // increased. - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -// TestRouterAdvertValidation tests that when the NIC is configured to handle -// NDP Router Advertisement packets, it validates the Router Advertisement -// properly before handling them. -func TestRouterAdvertValidation(t *testing.T) { - tests := []struct { - name string - src tcpip.Address - hopLimit uint8 - code uint8 - ndpPayload []byte - expectedSuccess bool - }{ - { - "OK", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - true, - }, - { - "NonLinkLocalSourceAddr", - addr1, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "HopLimitNot255", - lladdr0, - 254, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NonZeroCode", - lladdr0, - 255, - 1, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NDPPayloadTooSmall", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, - }, - false, - }, - { - "OKWithOptions", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - 2, 1, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - true, - }, - { - "OptionWithZeroLength", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - // Invalid as it has 0 length. - 2, 0, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } - - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } - } - }) - } -} diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go index ab24372e7..ab24372e7 100644..100755 --- a/pkg/tcpip/packet_buffer.go +++ b/pkg/tcpip/packet_buffer.go diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go index ad3cc24fa..ad3cc24fa 100644..100755 --- a/pkg/tcpip/packet_buffer_state.go +++ b/pkg/tcpip/packet_buffer_state.go diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD deleted file mode 100644 index e156b01f6..000000000 --- a/pkg/tcpip/ports/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "ports", - srcs = ["ports.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - ], -) - -go_test( - name = "ports_test", - srcs = ["ports_test.go"], - embed = [":ports"], - deps = [ - "//pkg/tcpip", - ], -) diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go new file mode 100755 index 000000000..badae1397 --- /dev/null +++ b/pkg/tcpip/ports/ports_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package ports + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Flags) beforeSave() {} +func (x *Flags) save(m state.Map) { + x.beforeSave() + m.Save("MostRecent", &x.MostRecent) + m.Save("LoadBalanced", &x.LoadBalanced) +} + +func (x *Flags) afterLoad() {} +func (x *Flags) load(m state.Map) { + m.Load("MostRecent", &x.MostRecent) + m.Load("LoadBalanced", &x.LoadBalanced) +} + +func init() { + state.Register("ports.Flags", (*Flags)(nil), state.Fns{Save: (*Flags).save, Load: (*Flags).load}) +} diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go deleted file mode 100644 index d6969d050..000000000 --- a/pkg/tcpip/ports/ports_test.go +++ /dev/null @@ -1,402 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ports - -import ( - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 - - fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") - fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") -) - -type portReserveTestAction struct { - port uint16 - ip tcpip.Address - want *tcpip.Error - flags Flags - release bool - device tcpip.NICID -} - -func TestPortReservation(t *testing.T) { - for _, test := range []struct { - tname string - actions []portReserveTestAction - }{ - { - tname: "bind to ip", - actions: []portReserveTestAction{ - {port: 80, ip: fakeIPAddress, want: nil}, - {port: 80, ip: fakeIPAddress1, want: nil}, - /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, - }, - }, - { - tname: "bind to inaddr any", - actions: []portReserveTestAction{ - {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - /* release fakeIPAddress, but anyIPAddress is still inuse */ - {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, - /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ - {port: 22, ip: anyIPAddress, want: nil, release: true}, - {port: 22, ip: fakeIPAddress, want: nil}, - }, - }, { - tname: "bind to zero port", - actions: []portReserveTestAction{ - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to ip with reuseport", - actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - - {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to inaddr any with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with device fails", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 3, want: nil}, - {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 1, want: nil}, - {port: 24, ip: fakeIPAddress, device: 2, want: nil}, - }, - }, { - tname: "bind to device and then without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "binding with reuseport and device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "mixing reuseport and not reuseport by binding to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: nil}, - }, - }, { - tname: "can't bind to 0 after mixing reuseport and not reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind and release", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - - // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with reuseport once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "release an unreserved device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, - // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - // Release all. - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, - }, - }, { - tname: "bind with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind twice with reuseaddr once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - }, - }, { - tname: "bind with reuseport, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, - }, - }, - } { - t.Run(test.tname, func(t *testing.T) { - pm := NewPortManager() - net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} - - for _, test := range test.actions { - if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) - continue - } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) - if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) - } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) - } - } - }) - - } -} - -func TestPickEphemeralPort(t *testing.T) { - customErr := &tcpip.Error{} - for _, test := range []struct { - name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr - }, - wantErr: customErr, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port == FirstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: FirstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port < FirstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) - } - }) - } -} - -func TestPickEphemeralPortStable(t *testing.T) { - customErr := &tcpip.Error{} - for _, test := range []struct { - name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr - }, - wantErr: customErr, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port == FirstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: FirstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port < FirstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) - if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) - } - }) - } -} diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD deleted file mode 100644 index a57752a7c..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_connect", - srcs = ["main.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go deleted file mode 100644 index 2239c1e66..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and connects to a peer. Similar to "nc <address> <port>". While the -// sample is running, attempts to connect to its IPv4 address will result in -// a RST segment. -// -// As an example of how to run it, a TUN device can be created and enabled on -// a linux host as follows (this only needs to be done once per boot): -// -// [sudo] ip tuntap add user <username> mode tun <device-name> -// [sudo] ip link set <device-name> up -// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name> -// -// A concrete example: -// -// $ sudo ip tuntap add user wedsonaf mode tun tun0 -// $ sudo ip link set tun0 up -// $ sudo ip addr add 192.168.1.1/24 dev tun0 -// -// Then one can run tun_tcp_connect as such: -// -// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234 -// -// This will attempt to connect to the linux host's stack. One can run nc in -// listen mode to accept a connect from tun_tcp_connect and exchange data. -package main - -import ( - "bufio" - "fmt" - "log" - "math/rand" - "net" - "os" - "strconv" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// writer reads from standard input and writes to the endpoint until standard -// input is closed. It signals that it's done by closing the provided channel. -func writer(ch chan struct{}, ep tcpip.Endpoint) { - defer func() { - ep.Shutdown(tcpip.ShutdownWrite) - close(ch) - }() - - r := bufio.NewReader(os.Stdin) - for { - v := buffer.NewView(1024) - n, err := r.Read(v) - if err != nil { - return - } - - v.CapLength(n) - for len(v) > 0 { - n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - if err != nil { - fmt.Println("Write failed:", err) - return - } - - v.TrimFront(int(n)) - } - } -} - -func main() { - if len(os.Args) != 6 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>") - } - - tunName := os.Args[1] - addrName := os.Args[2] - portName := os.Args[3] - remoteAddrName := os.Args[4] - remotePortName := os.Args[5] - - rand.Seed(time.Now().UnixNano()) - - addr := tcpip.Address(net.ParseIP(addrName).To4()) - remote := tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()), - } - - var localPort uint16 - if v, err := strconv.Atoi(portName); err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } else { - localPort = uint16(v) - } - - if v, err := strconv.Atoi(remotePortName); err != nil { - log.Fatalf("Unable to convert port %v: %v", remotePortName, err) - } else { - remote.Port = uint16(v) - } - - // Create the stack with ipv4 and tcp protocols, then add a tun-based - // NIC and ipv4 address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - fd, err := tun.Open(tunName) - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - }) - - // Create TCP endpoint. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - log.Fatal(e) - } - - // Bind if a port is specified. - if localPort != 0 { - if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil { - log.Fatal("Bind failed: ", err) - } - } - - // Issue connect request and wait for it to complete. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - terr := ep.Connect(remote) - if terr == tcpip.ErrConnectStarted { - fmt.Println("Connect is pending...") - <-notifyCh - terr = ep.GetSockOpt(tcpip.ErrorOption{}) - } - wq.EventUnregister(&waitEntry) - - if terr != nil { - log.Fatal("Unable to connect: ", terr) - } - - fmt.Println("Connected") - - // Start the writer in its own goroutine. - writerCompletedCh := make(chan struct{}) - go writer(writerCompletedCh, ep) // S/R-SAFE: sample code. - - // Read data and write to standard output until the peer closes the - // connection from its side. - wq.EventRegister(&waitEntry, waiter.EventIn) - for { - v, _, err := ep.Read(nil) - if err != nil { - if err == tcpip.ErrClosedForReceive { - break - } - - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - log.Fatal("Read() failed:", err) - } - - os.Stdout.Write(v) - } - wq.EventUnregister(&waitEntry) - - // The reader has completed. Now wait for the writer as well. - <-writerCompletedCh - - ep.Close() -} diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD deleted file mode 100644 index dad8ef399..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_echo", - srcs = ["main.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go deleted file mode 100644 index bca73cbb1..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and listens on a port. Data received by the server in the accepted -// connections is echoed back to the clients. -package main - -import ( - "flag" - "log" - "math/rand" - "net" - "os" - "strconv" - "strings" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var tap = flag.Bool("tap", false, "use tap istead of tun") -var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") - -func echo(wq *waiter.Queue, ep tcpip.Endpoint) { - defer ep.Close() - - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - for { - v, _, err := ep.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - return - } - - ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - } -} - -func main() { - flag.Parse() - if len(flag.Args()) != 3 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>") - } - - tunName := flag.Arg(0) - addrName := flag.Arg(1) - portName := flag.Arg(2) - - rand.Seed(time.Now().UnixNano()) - - // Parse the mac address. - maddr, err := net.ParseMAC(*mac) - if err != nil { - log.Fatalf("Bad MAC address: %v", *mac) - } - - // Parse the IP address. Support both ipv4 and ipv6. - parsedAddr := net.ParseIP(addrName) - if parsedAddr == nil { - log.Fatalf("Bad IP address: %v", addrName) - } - - var addr tcpip.Address - var proto tcpip.NetworkProtocolNumber - if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) - proto = ipv4.ProtocolNumber - } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) - proto = ipv6.ProtocolNumber - } else { - log.Fatalf("Unknown IP type: %v", addrName) - } - - localPort, err := strconv.Atoi(portName) - if err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } - - // Create the stack with ip and tcp protocols, then add a tun-based - // NIC and address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - var fd int - if *tap { - fd, err = tun.OpenTAP(tunName) - } else { - fd, err = tun.Open(tunName) - } - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{ - FDs: []int{fd}, - MTU: mtu, - EthernetHeader: *tap, - Address: tcpip.LinkAddress(maddr), - }) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, linkEP); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - log.Fatal(err) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) - if err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: 1, - }, - }) - - // Create TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if err != nil { - log.Fatal(e) - } - - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil { - log.Fatal("Bind failed: ", err) - } - - if err := ep.Listen(10); err != nil { - log.Fatal("Listen failed: ", err) - } - - // Wait for connections to appear. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - for { - n, wq, err := ep.Accept() - if err != nil { - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - log.Fatal("Accept() failed:", err) - } - - go echo(wq, n) // S/R-SAFE: sample code. - } -} diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD deleted file mode 100644 index b31ddba2f..000000000 --- a/pkg/tcpip/seqnum/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "seqnum", - srcs = ["seqnum.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", - visibility = ["//visibility:public"], -) diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go new file mode 100755 index 000000000..bf76f6ac4 --- /dev/null +++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package seqnum + diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD deleted file mode 100644 index 69077669a..000000000 --- a/pkg/tcpip/stack/BUILD +++ /dev/null @@ -1,87 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_template_instance( - name = "linkaddrentry_list", - out = "linkaddrentry_list.go", - package = "stack", - prefix = "linkAddrEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*linkAddrEntry", - "Linker": "*linkAddrEntry", - }, -) - -go_library( - name = "stack", - srcs = [ - "icmp_rate_limit.go", - "linkaddrcache.go", - "linkaddrentry_list.go", - "ndp.go", - "nic.go", - "registration.go", - "route.go", - "stack.go", - "stack_global_state.go", - "transport_demuxer.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", - visibility = ["//visibility:public"], - deps = [ - "//pkg/ilist", - "//pkg/rand", - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/waiter", - "@org_golang_x_time//rate:go_default_library", - ], -) - -go_test( - name = "stack_x_test", - size = "small", - srcs = [ - "ndp_test.go", - "stack_test.go", - "transport_demuxer_test.go", - "transport_test.go", - ], - deps = [ - ":stack", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stack_test", - size = "small", - srcs = ["linkaddrcache_test.go"], - embed = [":stack"], - deps = [ - "//pkg/sleep", - "//pkg/tcpip", - ], -) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go deleted file mode 100644 index 9946b8fe8..000000000 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "sync" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type testaddr struct { - addr tcpip.FullAddress - linkAddr tcpip.LinkAddress -} - -var testAddrs = func() []testaddr { - var addrs []testaddr - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - addrs = append(addrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } - return addrs -}() - -type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration - onLinkAddressRequest func() -} - -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testAddrs { - if ta.addr.Addr == addr { - r.cache.add(ta.addr, ta.linkAddr) - break - } - } -} - -func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "broadcast" { - return "mac_broadcast", true - } - return "", false -} - -func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 1 -} - -func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - - for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err - } - s.Fetch(true) - } -} - -func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testAddrs) - 1; i >= 0; i-- { - e := testAddrs[i] - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) - } - } - // Expect to find at least half of the most recent entries. - for i := 0; i < linkAddrCacheSize/2; i++ { - e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) - } - } - // The earliest entries should no longer be in the cache. - for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { - e := testAddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) - } - } -} - -func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - var wg sync.WaitGroup - for r := 0; r < 16; r++ { - wg.Add(1) - go func() { - for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan - } - wg.Done() - }() - } - wg.Wait() - - // All goroutines add in the same order and add more values than - // can fit in the cache, so our eviction strategy requires that - // the last entry be present and the first be missing. - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - e = testAddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testAddrs[0] - c.add(e.addr, e.linkAddr) - time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testAddrs[0] - l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - c.add(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != l2 { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) - } -} - -func TestCacheResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) - linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testAddrs { - got, err := getBlocking(c, ta.addr, linkRes) - if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) - } - if got != ta.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) - } - } - - // Check that after resolved, address stays in the cache and never returns WouldBlock. - for i := 0; i < 10; i++ { - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - } -} - -func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - // First, sanity check that resolution is working... - e := testAddrs[0] - got, err := getBlocking(c, e.addr, linkRes) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - before := atomic.LoadUint32(&requestCount) - - e.addr.Addr += "2" - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } - - if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -func TestCacheResolutionTimeout(t *testing.T) { - resolverDelay := 500 * time.Millisecond - expiration := resolverDelay / 10 - c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - - e := testAddrs[0] - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -// TestStaticResolution checks that static link addresses are resolved immediately and don't -// send resolution requests. -func TestStaticResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, time.Millisecond, 1) - linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute} - - addr := tcpip.Address("broadcast") - want := tcpip.LinkAddress("mac_broadcast") - got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err) - } - if got != want { - t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) - } -} diff --git a/pkg/tcpip/stack/linkaddrentry_list.go b/pkg/tcpip/stack/linkaddrentry_list.go new file mode 100755 index 000000000..61a45ddcb --- /dev/null +++ b/pkg/tcpip/stack/linkaddrentry_list.go @@ -0,0 +1,173 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type linkAddrEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (linkAddrEntryElementMapper) linkerFor(elem *linkAddrEntry) *linkAddrEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type linkAddrEntryList struct { + head *linkAddrEntry + tail *linkAddrEntry +} + +// Reset resets list l to the empty state. +func (l *linkAddrEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *linkAddrEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *linkAddrEntryList) Front() *linkAddrEntry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *linkAddrEntryList) Back() *linkAddrEntry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *linkAddrEntryList) PushFront(e *linkAddrEntry) { + linkAddrEntryElementMapper{}.linkerFor(e).SetNext(l.head) + linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + linkAddrEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *linkAddrEntryList) PushBack(e *linkAddrEntry) { + linkAddrEntryElementMapper{}.linkerFor(e).SetNext(nil) + linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + linkAddrEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *linkAddrEntryList) PushBackList(m *linkAddrEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + linkAddrEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + linkAddrEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *linkAddrEntryList) InsertAfter(b, e *linkAddrEntry) { + a := linkAddrEntryElementMapper{}.linkerFor(b).Next() + linkAddrEntryElementMapper{}.linkerFor(e).SetNext(a) + linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(b) + linkAddrEntryElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + linkAddrEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *linkAddrEntryList) InsertBefore(a, e *linkAddrEntry) { + b := linkAddrEntryElementMapper{}.linkerFor(a).Prev() + linkAddrEntryElementMapper{}.linkerFor(e).SetNext(a) + linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(b) + linkAddrEntryElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + linkAddrEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *linkAddrEntryList) Remove(e *linkAddrEntry) { + prev := linkAddrEntryElementMapper{}.linkerFor(e).Prev() + next := linkAddrEntryElementMapper{}.linkerFor(e).Next() + + if prev != nil { + linkAddrEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + linkAddrEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type linkAddrEntryEntry struct { + next *linkAddrEntry + prev *linkAddrEntry +} + +// Next returns the entry that follows e in the list. +func (e *linkAddrEntryEntry) Next() *linkAddrEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *linkAddrEntryEntry) Prev() *linkAddrEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *linkAddrEntryEntry) SetNext(elem *linkAddrEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *linkAddrEntryEntry) SetPrev(elem *linkAddrEntry) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 060a2e7c6..060a2e7c6 100644..100755 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go deleted file mode 100644 index 8d811eb8e..000000000 --- a/pkg/tcpip/stack/ndp_test.go +++ /dev/null @@ -1,2249 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "encoding/binary" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" - linkAddr1 = "\x02\x02\x03\x04\x05\x06" - linkAddr2 = "\x02\x02\x03\x04\x05\x07" - linkAddr3 = "\x02\x02\x03\x04\x05\x08" - defaultTimeout = 100 * time.Millisecond -) - -var ( - llAddr1 = header.LinkLocalAddr(linkAddr1) - llAddr2 = header.LinkLocalAddr(linkAddr2) - llAddr3 = header.LinkLocalAddr(linkAddr3) -) - -// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent -// tcpip.Subnet, and an address where the lower half of the address is composed -// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. -func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { - prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixBytes), - PrefixLen: 64, - } - - subnet := prefix.Subnet() - - var addr tcpip.AddressWithPrefix - if header.IsValidUnicastEthernetAddress(linkAddr) { - addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } - } - - return prefix, subnet, addr -} - -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - } - - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - // Should get the address immediately since we should not have performed - // DAD on it. - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) - } - - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } -} - -// ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionStatus. -type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - resolved bool - err *tcpip.Error -} - -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool -} - -type ndpPrefixEvent struct { - nicID tcpip.NICID - prefix tcpip.Subnet - // true if prefix was discovered, false if invalidated. - discovered bool -} - -type ndpAutoGenAddrEventType int - -const ( - newAddr ndpAutoGenAddrEventType = iota - invalidatedAddr -) - -type ndpAutoGenAddrEvent struct { - nicID tcpip.NICID - addr tcpip.AddressWithPrefix - eventType ndpAutoGenAddrEventType -} - -type ndpRDNSS struct { - addrs []tcpip.Address - lifetime time.Duration -} - -type ndpRDNSSEvent struct { - nicID tcpip.NICID - rdnss ndpRDNSS -} - -var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) - -// ndpDispatcher implements NDPDispatcher so tests can know when various NDP -// related events happen for test purposes. -type ndpDispatcher struct { - dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool - prefixC chan ndpPrefixEvent - rememberPrefix bool - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent - routeTable []tcpip.Route -} - -// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { - if n.dadC != nil { - n.dadC <- ndpDADEvent{ - nicID, - addr, - resolved, - err, - } - } -} - -// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ - nicID, - addr, - true, - } - } - - if !n.rememberRouter { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - }) - n.routeTable = rt - return true, rt -} - -// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ - nicID, - addr, - false, - } - } - - var rt []tcpip.Route - exclude := tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - } - - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) - } - } - n.routeTable = rt - return rt -} - -// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ - nicID, - prefix, - true, - } - } - - if !n.rememberPrefix { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: prefix, - NIC: nicID, - }) - n.routeTable = rt - return true, rt -} - -// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ - nicID, - prefix, - false, - } - } - - var rt []tcpip.Route - exclude := tcpip.Route{ - Destination: prefix, - NIC: nicID, - } - - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) - } - } - n.routeTable = rt - return rt -} - -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { - if n.autoGenAddrC != nil { - n.autoGenAddrC <- ndpAutoGenAddrEvent{ - nicID, - addr, - newAddr, - } - } - return true -} - -func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if n.autoGenAddrC != nil { - n.autoGenAddrC <- ndpAutoGenAddrEvent{ - nicID, - addr, - invalidatedAddr, - } - } -} - -// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. -func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { - if n.rdnssC != nil { - n.rdnssC <- ndpRDNSSEvent{ - nicID, - ndpRDNSS{ - addrs, - lifetime, - }, - } - } -} - -// TestDADResolve tests that an address successfully resolves after performing -// DAD for various values of DupAddrDetectTransmits and RetransmitTimer. -// Included in the subtests is a test to make sure that an invalid -// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. -func TestDADResolve(t *testing.T) { - tests := []struct { - name string - dupAddrDetectTransmits uint8 - retransTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - {"1:1s:1s", 1, time.Second, time.Second}, - {"2:1s:1s", 2, time.Second, time.Second}, - {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, - // 0s is an invalid RetransmitTimer timer and will be fixed to - // the default RetransmitTimer value of 1s. - {"1:0s:1s", 1, 0, time.Second}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = test.retransTimer - opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit - - // Should have sent an NDP NS immediately. - if got := stat.Value(); got != 1 { - t.Fatalf("got NeighborSolicit = %d, want = 1", got) - - } - - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - // Wait for the remaining time - some delta (500ms), to - // make sure the address is still not resolved. - const delta = 500 * time.Millisecond - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - // Wait for DAD to resolve. - select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) - } - - // Should not have sent any more NS messages. - if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) { - t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) - } - - // Validate the sent Neighbor Solicitation messages. - for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p := <-e.C - - // Make sure its an IPv6 packet. - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Check NDP packet. - checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr1))) - } - }) - } - -} - -// TestDADFail tests to make sure that the DAD process fails if another node is -// detected to be performing DAD on the same address (receive an NS message from -// a node doing DAD for the same address), or if another node is detected to own -// the address already (receive an NA message for the tentative address). -func TestDADFail(t *testing.T) { - tests := []struct { - name string - makeBuf func(tgt tcpip.Address) buffer.Prependable - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - }{ - { - "RxSolicit", - func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborSolicit - }, - }, - { - "RxAdvert", - func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborAdvert - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = time.Second * 2 - - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - // Receive a packet to simulate multiple nodes owning or - // attempting to own the same address. - hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) - if got := stat.Value(); got != 1 { - t.Fatalf("got stat = %d, want = 1", got) - } - - // Wait for DAD to fail and make sure the address did - // not get resolved. - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - }) - } -} - -// TestDADStop tests to make sure that the DAD process stops when an address is -// removed. -func TestDADStop(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - ndpConfigs := stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } - - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - // Remove the address. This should stop DAD. - if err := s.RemoveAddress(1, addr1); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) - } - - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") - } - - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Fatalf("got NeighborSolicit = %d, want <= 1", got) - } -} - -// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NDP configurations using an invalid NICID. -func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - - // No NIC with ID 1 yet. - if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) - } -} - -// TestSetNDPConfigurations tests that we can update and use per-interface NDP -// configurations without affecting the default NDP configurations or other -// interfaces' configurations. -func TestSetNDPConfigurations(t *testing.T) { - tests := []struct { - name string - dupAddrDetectTransmits uint8 - retransmitTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - "OK", - 1, - time.Second, - time.Second, - }, - { - "Invalid Retransmit Timer", - 1, - 0, - time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - }) - - // This NIC(1)'s NDP configurations will be updated to - // be different from the default. - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Created before updating NIC(1)'s NDP configurations - // but updating NIC(1)'s NDP configurations should not - // affect other existing NICs. - if err := s.CreateNIC(2, e); err != nil { - t.Fatalf("CreateNIC(2) = %s", err) - } - - // Update the NDP configurations on NIC(1) to use DAD. - configs := stack.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - RetransmitTimer: test.retransmitTimer, - } - if err := s.SetNDPConfigurations(1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) - } - - // Created after updating NIC(1)'s NDP configurations - // but the stack's default NDP configurations should not - // have been updated. - if err := s.CreateNIC(3, e); err != nil { - t.Fatalf("CreateNIC(3) = %s", err) - } - - // Add addresses for each NIC. - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) - } - if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) - } - - // Address should not be considered bound to NIC(1) yet - // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - // Should get the address on NIC(2) and NIC(3) - // immediately since we should not have performed DAD on - // it as the stack was configured to not do DAD by - // default and we only updated the NDP configurations on - // NIC(1). - addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) - } - if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) - } - addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) - } - if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) - } - - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - // Wait for DAD to resolve. - select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) - } - }) - } -} - -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(0) - ra := header.NDPRouterAdvert(pkt.NDPPayload()) - opts := ra.Options() - opts.Serialize(optSer) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} -} - -// raBuf returns a valid NDP Router Advertisement. -// -// Note, raBuf does not populate any of the RA fields other than the -// Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) -} - -// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix -// Information option. -// -// Note, raBufWithPI does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { - flags := uint8(0) - if onLink { - // The OnLink flag is the 7th bit in the flags byte. - flags |= 1 << 7 - } - if auto { - // The Address Auto-Configuration flag is the 6th bit in the - // flags byte. - flags |= 1 << 6 - } - - // A valid header.NDPPrefixInformation must be 30 bytes. - buf := [30]byte{} - // The first byte in a header.NDPPrefixInformation is the Prefix Length - // field. - buf[0] = uint8(prefix.PrefixLen) - // The 2nd byte within a header.NDPPrefixInformation is the Flags field. - buf[1] = flags - // The Valid Lifetime field starts after the 2nd byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[2:], vl) - // The Preferred Lifetime field starts after the 6th byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[6:], pl) - // The Prefix Address field starts after the 14th byte within a - // header.NDPPrefixInformation. - copy(buf[14:], prefix.Address) - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ - header.NDPPrefixInformation(buf[:]), - }) -} - -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - t.Parallel() - - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - }) - s.SetForwarding(forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - case <-time.After(defaultTimeout): - } - }) - } -} - -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - - // Rx an RA with short lifetime. - lifetime := time.Duration(1) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime))) - select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != llAddr2 { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) - } - - // Wait for the normal invalidation time plus an extra second to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the router in the first place. - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - case <-time.After(lifetime*time.Second + defaultTimeout): - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) - } -} - -func TestRouterDiscovery(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), - rememberRouter: true, - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - }) - - waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) { - t.Helper() - - select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for router discovery event") - } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - case <-time.After(defaultTimeout): - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) - - // Should have a default route through the discovered router. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Rx an RA from another router (lladdr3) with non-zero lifetime. - l3Lifetime := time.Duration(6) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) - waitForEvent(llAddr3, true, defaultTimeout) - - // Should have default routes through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Rx an RA from lladdr2 with lesser lifetime. - l2Lifetime := time.Duration(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - case <-time.After(defaultTimeout): - } - - // Should still have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Wait for lladdr2's router invalidation timer to fire. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout) - - // Should no longer have the default route through lladdr2. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) - - // Should have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - waitForEvent(llAddr2, false, defaultTimeout) - - // Should have deleted the default route through the router that just - // got invalidated. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Wait for lladdr3's router invalidation timer to fire. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout) - - // Should not have any routes now that all discovered routers have been - // invalidated. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } -} - -// TestRouterDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. -func TestRouterDiscoveryMaxRouters(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), - rememberRouter: true, - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{} - - // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { - linkAddr := []byte{2, 2, 3, 4, 5, 0} - linkAddr[5] = byte(i) - llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) - - if i <= stack.MaxDiscoveredDefaultRouters { - expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} - select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != llAddr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") - } - - } else { - select { - case <-ndpDisp.routerC: - t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - case <-time.After(defaultTimeout): - } - } - } - - // Should only have default routes for the first - // stack.MaxDiscoveredDefaultRouters discovered routers. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) - } -} - -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - t.Parallel() - - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - }) - s.SetForwarding(forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - case <-time.After(defaultTimeout): - } - }) - } -} - -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - - // Rx an RA with prefix with a short lifetime. - const lifetime = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetime, 0)) - select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) - } - - // Wait for the normal invalidation time plus some buffer to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the prefix in the first place. - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - case <-time.After(lifetime*time.Second + defaultTimeout): - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) - } -} - -func TestPrefixDiscovery(t *testing.T) { - t.Parallel() - - prefix1, subnet1, _ := prefixSubnetAddr(0, "") - prefix2, subnet2, _ := prefixSubnetAddr(1, "") - prefix3, subnet3, _ := prefixSubnetAddr(2, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), - rememberPrefix: true, - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - waitForEvent := func(subnet tcpip.Subnet, discovered bool, timeout time.Duration) { - t.Helper() - - select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) - } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") - } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - case <-time.After(defaultTimeout): - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - waitForEvent(subnet1, true, defaultTimeout) - - // Should have added a device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - waitForEvent(subnet2, true, defaultTimeout) - - // Should have added a device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - waitForEvent(subnet3, true, defaultTimeout) - - // Should have added a device route for subnet3 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - waitForEvent(subnet1, false, defaultTimeout) - - // Should have removed the device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - case <-time.After(defaultTimeout): - } - - // Should not have updated route table. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Wait for prefix2's most recent invalidation timer plus some buffer to - // expire. - waitForEvent(subnet2, false, time.Duration(lifetime)*time.Second+defaultTimeout) - - // Should have removed the device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } - - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - waitForEvent(subnet3, false, defaultTimeout) - - // Should not have any routes. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } -} - -func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - subnet := prefix.Subnet() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), - rememberPrefix: true, - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - waitForEvent := func(discovered bool, timeout time.Duration) { - t.Helper() - - select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Errorf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Errorf("got r.prefix = %s, want = %s", r.prefix, subnet) - } - if r.discovered != discovered { - t.Errorf("got r.discovered = %t, want = %t", r.discovered, discovered) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") - } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix in an NDP Prefix Information option (PI) - // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - waitForEvent(true, defaultTimeout) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): - } - - // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - waitForEvent(false, testInfiniteLifetime) - - // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - waitForEvent(true, defaultTimeout) - - // Receive an RA with prefix with an infinite lifetime. - // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): - } - - // Receive an RA with 0 lifetime. - // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) - waitForEvent(false, defaultTimeout) -} - -// TestPrefixDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. -func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - expectedRt := [stack.MaxDiscoveredOnLinkPrefixes]tcpip.Route{} - prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} - - // Receive an RA with 2 more than the max number of discovered on-link - // prefixes. - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} - prefixAddr[7] = byte(i) - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixAddr[:]), - PrefixLen: 64, - } - prefixes[i] = prefix.Subnet() - buf := [30]byte{} - buf[0] = uint8(prefix.PrefixLen) - buf[1] = 128 - binary.BigEndian.PutUint32(buf[2:], 10) - copy(buf[14:], prefix.Address) - - optSer[i] = header.NDPPrefixInformation(buf[:]) - - if i < stack.MaxDiscoveredOnLinkPrefixes { - expectedRt[i] = tcpip.Route{prefixes[i], tcpip.Address([]byte(nil)), 1} - } - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < stack.MaxDiscoveredOnLinkPrefixes { - select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != prefixes[i] { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, prefixes[i]) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") - } - } else { - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") - case <-time.After(defaultTimeout): - } - } - } - - // Should only have device routes for the first - // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) - } -} - -// Checks to see if list contains an IPv6 address, item. -func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: item, - } - - for _, i := range list { - if i == protocolAddress { - return true - } - } - - return false -} - -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - t.Parallel() - - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - }) - s.SetForwarding(forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - case <-time.After(defaultTimeout): - } - }) - } -} - -// TestAutoGenAddr tests that an address is properly generated and invalidated -// when configured to do so. -func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - waitForEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != eventType { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, eventType) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for addr auto gen event") - } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - case <-time.After(defaultTimeout): - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - waitForEvent(addr1, newAddr, defaultTimeout) - if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - case <-time.After(defaultTimeout): - } - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - waitForEvent(addr2, newAddr, defaultTimeout) - if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } - - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - case <-time.After(defaultTimeout): - } - - // Wait for addr of prefix1 to be invalidated. - waitForEvent(addr1, invalidatedAddr, newMinVLDuration+defaultTimeout) - if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } -} - -// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an -// auto-generated address only gets updated when required to, as specified in -// RFC 4862 section 5.5.3.e. -func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { - const infiniteVL = 4294967295 - const newMinVL = 5 - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - ovl uint32 - nvl uint32 - evl uint32 - }{ - // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than - // it. - { - "LargeVLToVLLessThanMinVLForUpdate", - 9999, - 1, - newMinVL, - }, - { - "LargeVLTo0", - 9999, - 0, - newMinVL, - }, - { - "InfiniteVLToVLLessThanMinVLForUpdate", - infiniteVL, - 1, - newMinVL, - }, - { - "InfiniteVLTo0", - infiniteVL, - 0, - newMinVL, - }, - - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. - { - "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, - }, - - // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. - { - "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, - }, - { - "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, - }, - { - "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, - }, - } - - const delta = 500 * time.Millisecond - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix with initial VL, test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != newAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - - // Receive an new RA with prefix with new VL, test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - - // - // Validate that the VL for the address got set to - // test.evl. - // - - // Make sure we do not get any invalidation events - // until atleast 500ms (delta) before test.evl. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - delta): - } - - // Wait for another second (2x delta), but now we expect - // the invalidation event. - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != invalidatedAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) - } - case <-time.After(2 * delta): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } -} - -// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed -// by the user, its resources will be cleaned up and an invalidation event will -// be sent to the integrator. -func TestAutoGenAddrRemoval(t *testing.T) { - t.Parallel() - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix with its valid lifetime = lifetime. - const lifetime = 5 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != newAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - - // Remove the address. - if err := s.RemoveAddress(1, addr.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) - } - - // Should get the invalidation event immediately. - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != invalidatedAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly received an auto gen addr event") - case <-time.After(lifetime*time.Second + defaultTimeout): - } -} - -// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that -// is already assigned to the NIC, the static address remains. -func TestAutoGenAddrStaticConflict(t *testing.T) { - t.Parallel() - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) - } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive a PI where the generated address will be the same as the one - // that we already have assigned statically. - const lifetime = 5 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") - case <-time.After(defaultTimeout): - } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Should not get an invalidation event after the PI's invalidation - // time. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetime*time.Second + defaultTimeout): - } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } -} - -// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event -// to the integrator when an RA is received with the NDP Recursive DNS Server -// option with at least one valid address. -func TestNDPRecursiveDNSServerDispatch(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - opt header.NDPRecursiveDNSServer - expected *ndpRDNSS - }{ - { - "Unspecified", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }), - nil, - }, - { - "Multicast", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }), - nil, - }, - { - "OptionTooSmall", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, - }), - nil, - }, - { - "0Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - }), - nil, - }, - { - "Valid1Address", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - }, - 2 * time.Second, - }, - }, - { - "Valid2Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - }, - time.Second, - }, - }, - { - "Valid3Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 0, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", - }, - 0, - }, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - // We do not expect more than a single RDNSS - // event at any time for this test. - rdnssC: make(chan ndpRDNSSEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) - - if test.expected != nil { - select { - case e := <-ndpDisp.rdnssC: - if e.nicID != 1 { - t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) - } - if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { - t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) - } - if e.rdnss.lifetime != test.expected.lifetime { - t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) - } - default: - t.Fatal("expected an RDNSS option event") - } - } - - // Should have no more RDNSS options. - select { - case e := <-ndpDisp.rdnssC: - t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) - default: - } - }) - } -} diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go new file mode 100755 index 000000000..2551126f2 --- /dev/null +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -0,0 +1,125 @@ +// automatically generated by stateify. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *linkAddrEntryList) beforeSave() {} +func (x *linkAddrEntryList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *linkAddrEntryList) afterLoad() {} +func (x *linkAddrEntryList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *linkAddrEntryEntry) beforeSave() {} +func (x *linkAddrEntryEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *linkAddrEntryEntry) afterLoad() {} +func (x *linkAddrEntryEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *TransportEndpointID) beforeSave() {} +func (x *TransportEndpointID) save(m state.Map) { + x.beforeSave() + m.Save("LocalPort", &x.LocalPort) + m.Save("LocalAddress", &x.LocalAddress) + m.Save("RemotePort", &x.RemotePort) + m.Save("RemoteAddress", &x.RemoteAddress) +} + +func (x *TransportEndpointID) afterLoad() {} +func (x *TransportEndpointID) load(m state.Map) { + m.Load("LocalPort", &x.LocalPort) + m.Load("LocalAddress", &x.LocalAddress) + m.Load("RemotePort", &x.RemotePort) + m.Load("RemoteAddress", &x.RemoteAddress) +} + +func (x *GSOType) save(m state.Map) { + m.SaveValue("", (int)(*x)) +} + +func (x *GSOType) load(m state.Map) { + m.LoadValue("", new(int), func(y interface{}) { *x = (GSOType)(y.(int)) }) +} + +func (x *GSO) beforeSave() {} +func (x *GSO) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("NeedsCsum", &x.NeedsCsum) + m.Save("CsumOffset", &x.CsumOffset) + m.Save("MSS", &x.MSS) + m.Save("L3HdrLen", &x.L3HdrLen) + m.Save("MaxSize", &x.MaxSize) +} + +func (x *GSO) afterLoad() {} +func (x *GSO) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("NeedsCsum", &x.NeedsCsum) + m.Load("CsumOffset", &x.CsumOffset) + m.Load("MSS", &x.MSS) + m.Load("L3HdrLen", &x.L3HdrLen) + m.Load("MaxSize", &x.MaxSize) +} + +func (x *TransportEndpointInfo) beforeSave() {} +func (x *TransportEndpointInfo) save(m state.Map) { + x.beforeSave() + m.Save("NetProto", &x.NetProto) + m.Save("TransProto", &x.TransProto) + m.Save("ID", &x.ID) + m.Save("BindNICID", &x.BindNICID) + m.Save("BindAddr", &x.BindAddr) + m.Save("RegisterNICID", &x.RegisterNICID) +} + +func (x *TransportEndpointInfo) afterLoad() {} +func (x *TransportEndpointInfo) load(m state.Map) { + m.Load("NetProto", &x.NetProto) + m.Load("TransProto", &x.TransProto) + m.Load("ID", &x.ID) + m.Load("BindNICID", &x.BindNICID) + m.Load("BindAddr", &x.BindAddr) + m.Load("RegisterNICID", &x.RegisterNICID) +} + +func (x *multiPortEndpoint) beforeSave() {} +func (x *multiPortEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("endpointsArr", &x.endpointsArr) + m.Save("endpointsMap", &x.endpointsMap) + m.Save("reuse", &x.reuse) +} + +func (x *multiPortEndpoint) afterLoad() {} +func (x *multiPortEndpoint) load(m state.Map) { + m.Load("endpointsArr", &x.endpointsArr) + m.Load("endpointsMap", &x.endpointsMap) + m.Load("reuse", &x.reuse) +} + +func init() { + state.Register("stack.linkAddrEntryList", (*linkAddrEntryList)(nil), state.Fns{Save: (*linkAddrEntryList).save, Load: (*linkAddrEntryList).load}) + state.Register("stack.linkAddrEntryEntry", (*linkAddrEntryEntry)(nil), state.Fns{Save: (*linkAddrEntryEntry).save, Load: (*linkAddrEntryEntry).load}) + state.Register("stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load}) + state.Register("stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load}) + state.Register("stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load}) + state.Register("stack.TransportEndpointInfo", (*TransportEndpointInfo)(nil), state.Fns{Save: (*TransportEndpointInfo).save, Load: (*TransportEndpointInfo).load}) + state.Register("stack.multiPortEndpoint", (*multiPortEndpoint)(nil), state.Fns{Save: (*multiPortEndpoint).save, Load: (*multiPortEndpoint).load}) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go deleted file mode 100644 index 8fc034ca1..000000000 --- a/pkg/tcpip/stack/stack_test.go +++ /dev/null @@ -1,2188 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package stack_test contains tests for the stack. It is in its own package so -// that the tests can also validate that all definitions needed to implement -// transport and network protocols are properly exported by the stack package. -package stack_test - -import ( - "bytes" - "fmt" - "math" - "sort" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 - fakeDefaultPrefixLen = 8 - - // fakeControlProtocol is used for control packets that represent - // destination port unreachable. - fakeControlProtocol tcpip.TransportProtocolNumber = 2 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fakeNetworkEndpoint struct { - nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int - proto *fakeNetworkProtocol - dispatcher stack.TransportDispatcher - ep stack.LinkEndpoint -} - -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) -} - -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID -} - -func (f *fakeNetworkEndpoint) PrefixLen() int { - return f.prefixLen -} - -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { - return &f.id -} - -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { - // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - - // Consume the network header. - b := pkt.Data.First() - pkt.Data.TrimFront(fakeNetHeaderLen) - - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { - return - } - - pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) - return - } - - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) -} - -func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fakeNetHeaderLen -} - -func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.ep.Capabilities() -} - -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { - // Increment the sent packet count in the protocol descriptor. - f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ - - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := pkt.Header.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) - - if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - } - if loop&stack.PacketOut == 0 { - return nil - } - - return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - panic("not implemented") -} - -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported -} - -func (*fakeNetworkEndpoint) Close() {} - -type fakeNetGoodOption bool - -type fakeNetBadOption bool - -type fakeNetInvalidValueOption int - -type fakeNetOptions struct { - good bool -} - -// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the -// number of packets sent and received via endpoints of this protocol. The index -// where packets are added is given by the packet's destination address MOD 10. -type fakeNetworkProtocol struct { - packetCount [10]int - sendPacketCount [10]int - opts fakeNetOptions -} - -func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fakeNetNumber -} - -func (f *fakeNetworkProtocol) MinimumPacketSize() int { - return fakeNetHeaderLen -} - -func (f *fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - -func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { - return f.packetCount[int(intfAddr)%len(f.packetCount)] -} - -func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) -} - -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { - return &fakeNetworkEndpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - proto: f, - dispatcher: dispatcher, - ep: ep, - }, nil -} - -func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case fakeNetGoodOption: - f.opts.good = bool(v) - return nil - case fakeNetInvalidValueOption: - return tcpip.ErrInvalidOptionValue - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *fakeNetGoodOption: - *v = fakeNetGoodOption(f.opts.good) - return nil - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func fakeNetFactory() stack.NetworkProtocol { - return &fakeNetworkProtocol{} -} - -func TestNetworkReceive(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and two - // addresses attached to it: 1 & 2. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Make sure packet with wrong address is not delivered. - buf[0] = 3 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to first endpoint. - buf[0] = 1 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to second endpoint. - buf[0] = 2 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet that is too small is dropped. - buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } -} - -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { - r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - return send(r, payload) -} - -func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) -} - -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := sendTo(s, addr, payload); err != nil { - t.Error("sendTo failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("sendTo packet count: got = %d, want %d", got, want) - } -} - -func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := send(r, payload); err != nil { - t.Error("send failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("send packet count: got = %d, want %d", got, want) - } -} - -func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { - t.Helper() - if gotErr := send(r, payload); gotErr != wantErr { - t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { - t.Helper() - if gotErr := sendTo(s, addr, payload); gotErr != wantErr { - t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we expect to receive it. - want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we do NOT expect to receive it. - want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { - t.Helper() - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if got := fakeNet.PacketCount(localAddrByte); got != want { - t.Errorf("receive packet count: got = %d, want %d", got, want) - } -} - -func TestNetworkSend(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and one - // address: 1. The route table sends all packets through the only - // existing nic. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("NewNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", ep, nil) -} - -func TestNetworkSendMultiRoute(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Send a packet to an odd destination. - testSendTo(t, s, "\x05", ep1, nil) - - // Send a packet to an even destination. - testSendTo(t, s, "\x06", ep2, nil) -} - -func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { - r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - defer r.Release() - - if r.LocalAddress != expectedSrcAddr { - t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress) - } - - if r.RemoteAddress != dstAddr { - t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress) - } -} - -func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { - _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) - } -} - -func TestRoutes(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Test routes to odd address. - testRoute(t, s, 0, "", "\x05", "\x01") - testRoute(t, s, 0, "\x01", "\x05", "\x01") - testRoute(t, s, 1, "\x01", "\x05", "\x01") - testRoute(t, s, 0, "\x03", "\x05", "\x03") - testRoute(t, s, 1, "\x03", "\x05", "\x03") - - // Test routes to even address. - testRoute(t, s, 0, "", "\x06", "\x02") - testRoute(t, s, 0, "\x02", "\x06", "\x02") - testRoute(t, s, 2, "\x02", "\x06", "\x02") - testRoute(t, s, 0, "\x04", "\x06", "\x04") - testRoute(t, s, 2, "\x04", "\x06", "\x04") - - // Try to send to odd numbered address from even numbered ones, then - // vice-versa. - testNoRoute(t, s, 0, "\x02", "\x05") - testNoRoute(t, s, 2, "\x02", "\x05") - testNoRoute(t, s, 0, "\x04", "\x05") - testNoRoute(t, s, 2, "\x04", "\x05") - - testNoRoute(t, s, 0, "\x01", "\x06") - testNoRoute(t, s, 1, "\x01", "\x06") - testNoRoute(t, s, 0, "\x03", "\x06") - testNoRoute(t, s, 1, "\x03", "\x06") -} - -func TestAddressRemoval(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Send and receive packets, and verify they are received. - buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - - // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) - } -} - -func TestAddressRemovalWithRouteHeld(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - // Send and receive packets, and verify they are received. - buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - - // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) - } -} - -func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { - t.Helper() - info, ok := s.NICInfo()[nicID] - if !ok { - t.Fatalf("NICInfo() failed to find nicID=%d", nicID) - } - if len(addr) == 0 { - // No address given, verify that there is no address assigned to the NIC. - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) - } - } - return - } - // Address given, verify the address is assigned to the NIC and no other - // address is. - found := false - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber { - if a.AddressWithPrefix.Address == addr { - found = true - } else { - t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) - } - } - } - if !found { - t.Errorf("verify address: couldn't find %s on the NIC", addr) - } -} - -func TestEndpointExpiration(t *testing.T) { - const ( - localAddrByte byte = 0x01 - remoteAddr tcpip.Address = "\x03" - noAddr tcpip.Address = "" - nicID tcpip.NICID = 1 - ) - localAddr := tcpip.Address([]byte{localAddrByte}) - - for _, promiscuous := range []bool{true, false} { - for _, spoofing := range []bool{true, false} { - t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - buf[0] = localAddrByte - - if promiscuous { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - } - - if spoofing { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - } - - // 1. No Address yet, send should only work for spoofing, receive for - // promiscuous mode. - //----------------------- - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - - // 2. Add Address, everything should work. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 3. Remove the address, send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - - // 4. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 5. Take a reference to the endpoint by getting a route. Verify that - // we can still send/receive, including sending using the route. - //----------------------- - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 6. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - - // 7. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 8. Remove the route, sendTo/recv should still work. - //----------------------- - r.Release() - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 9. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) - } - }) - } - } -} - -func TestPromiscuousMode(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it doesn't get delivered as we don't - // have a matching endpoint. - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - - // Set promiscuous mode, then check that packet is delivered. - if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - - // Check that we can't get a route as there is no local address. - _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) - } - - // Set promiscuous mode to false, then check that packet can't be - // delivered anymore. - if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func TestSpoofingWithAddress(t *testing.T) { - localAddr := tcpip.Address("\x01") - nonExistentLocalAddr := tcpip.Address("\x02") - dstAddr := tcpip.Address("\x03") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet works. - testSendTo(t, s, dstAddr, ep, nil) - testSend(t, r, ep, nil) - - // FindRoute should also work with a local address that exists on the NIC. - r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != localAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet using the route works. - testSend(t, r, ep, nil) -} - -func TestSpoofingNoAddress(t *testing.T) { - nonExistentLocalAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) - } - // Sending a packet works. - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) -} - -func verifyRoute(gotRoute, wantRoute stack.Route) error { - if gotRoute.LocalAddress != wantRoute.LocalAddress { - return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) - } - if gotRoute.RemoteAddress != wantRoute.RemoteAddress { - return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) - } - if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { - return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) - } - if gotRoute.NextHop != wantRoute.NextHop { - return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) - } - return nil -} - -func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - s.SetRouteTable([]tcpip.Route{}) - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) - } - - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) - } - r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) - } -} - -func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} - // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} - nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") - // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} - nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") - - // Create a new stack with two NICs. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - if err := s.CreateNIC(2, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) - } - - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) - } - - // Set the initial route table. - rt := []tcpip.Route{ - {Destination: nic1Addr.Subnet(), NIC: 1}, - {Destination: nic2Addr.Subnet(), NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, - } - s.SetRouteTable(rt) - - // When an interface is given, the route for a broadcast goes through it. - r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - - // When an interface is not given, it consults the route table. - // 1. Case: Using the default route. - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) - } - - // 2. Case: Having an explicit route for broadcast will select that one. - rt = append( - []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, - }, - rt..., - ) - s.SetRouteTable(rt) - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) - } -} - -func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { - for _, tc := range []struct { - name string - routeNeeded bool - address tcpip.Address - }{ - // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255 - // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff - {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"}, - {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"}, - {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"}, - {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"}, - {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"}, - - // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112] - {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 link-local address starts with fe80::/10. - {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"}, - {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 addresses that are neither multicast nor link-local. - {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - } { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - s.SetRouteTable([]tcpip.Route{}) - - var anyAddr tcpip.Address - if len(tc.address) == header.IPv4AddressSize { - anyAddr = header.IPv4Any - } else { - anyAddr = header.IPv6Any - } - - want := tcpip.ErrNetworkUnreachable - if tc.routeNeeded { - want = tcpip.ErrNoRoute - } - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) - } - - if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { - // Route table is empty but we need a route, this should cause an error. - if err != tcpip.ErrNoRoute { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) - } - } else { - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err) - } - if r.LocalAddress != anyAddr { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, anyAddr) - } - if r.RemoteAddress != tc.address { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, tc.address) - } - } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - }) - } -} - -// Add a range of addresses, then check that a packet is delivered. -func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { - t.Helper() - - // Loop over all addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - - addrBytes := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - addr := tcpip.Address(addrBytes) - wantNicID := nicID - // The subnet and broadcast addresses are skipped. - if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { - wantNicID = 0 - } - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) - } - addrBytes[0]++ - } - - // Trying the next address should always fail since it is outside the range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) - } -} - -// Set a range of addresses, then remove it again, and check at each step that -// CheckLocalAddress returns the correct NIC for each address or zero if not -// existent. -func TestCheckLocalAddressForSubnet(t *testing.T) { - const nicID tcpip.NICID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) - - if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) - - if err := s.RemoveAddressRange(nicID, subnet); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) -} - -// Set a range of addresses, then send a packet to a destination outside the -// range and then check it doesn't get delivered. -func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func TestNetworkOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{}, - }) - - // Try an unsupported network protocol. - if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.NetworkProtocol) - }{ - {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { - t.Helper() - fakeNet := p.(*fakeNetworkProtocol) - if fakeNet.opts.good != true { - t.Fatalf("fakeNet.opts.good = false, want = true") - } - var v fakeNetGoodOption - if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) - } - }}, - {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) - } - } -} - -func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { - ranges, ok := s.NICAddressRanges()[id] - if !ok { - return false - } - for _, r := range ranges { - if r == addrRange { - return true - } - } - return false -} - -func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addr := tcpip.Address("\x01\x01\x01\x01") - mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - addrRange, err := tcpip.NewSubnet(addr, mask) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.RemoveAddressRange(1, addrRange); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } -} - -func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { - for _, addrLen := range []int{4, 16} { - t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { - for canBe := 0; canBe < 3; canBe++ { - t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { - for never := 0; never < 3; never++ { - t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - // Insert <canBe> primary and <never> never-primary addresses. - // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{}) - for i := 0; i < canBe+never; i++ { - var behavior stack.PrimaryEndpointBehavior - if i < canBe { - behavior = stack.CanBePrimaryEndpoint - } else { - behavior = stack.NeverPrimaryEndpoint - } - // Add an address and in case of a primary one include a - // prefixLen. - address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if behavior == stack.CanBePrimaryEndpoint { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, - } - if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil { - t.Fatal("AddProtocolAddressWithOptions failed:", err) - } - // Remember the address/prefix. - primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} - } else { - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - } - } - // Check that GetMainNICAddress returns an address if at least - // one primary address was added. In that case make sure the - // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if len(primaryAddrAdded) == 0 { - // No primary addresses present. - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) - } - } else { - // At least one primary address was added, verify the returned - // address is in the list of primary addresses we added. - if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) - } - } - }) - } - }) - } - }) - } -} - -func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - for _, tc := range []struct { - name string - address tcpip.Address - prefixLen int - }{ - {"IPv4", "\x01\x01\x01\x01", 24}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, - } { - t.Run(tc.name, func(t *testing.T) { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tc.address, - PrefixLen: tc.prefixLen, - }, - } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) - } - - // Check that we get the right initial address and prefix length. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { - t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) - } - - if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - - // Check that we get no address after removal. - gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) - } - }) - } -} - -// Simple network address generator. Good for 255 addresses. -type addressGenerator struct{ cnt byte } - -func (g *addressGenerator) next(addrLen int) tcpip.Address { - g.cnt++ - return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen)) -} - -func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { - t.Helper() - - if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) - } - - sort.Slice(gotAddresses, func(i, j int) bool { - return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address - }) - sort.Slice(expectedAddresses, func(i, j int) bool { - return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address - }) - - for i, gotAddr := range gotAddresses { - expectedAddr := expectedAddresses[i] - if gotAddr != expectedAddr { - t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr) - } - } -} - -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestNICStats(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) - } - - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) - } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } - - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) - } - - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - select { - case <-ep2.C: - default: - t.Fatal("Packet not forwarded") - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// (or lack there-of if disabled (default)). Note, DAD will be disabled in -// these tests. -func TestNICAutoGenAddr(t *testing.T) { - tests := []struct { - name string - autoGen bool - linkAddr tcpip.LinkAddress - shouldGen bool - }{ - { - "Disabled", - false, - linkAddr1, - false, - }, - { - "Enabled", - true, - linkAddr1, - true, - }, - { - "Nil MAC", - true, - tcpip.LinkAddress([]byte(nil)), - false, - }, - { - "Empty MAC", - true, - tcpip.LinkAddress(""), - false, - }, - { - "Invalid MAC", - true, - tcpip.LinkAddress("\x01\x02\x03"), - false, - }, - { - "Multicast MAC", - true, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - false, - }, - { - "Unspecified MAC", - true, - tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - } - - if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. - opts.AutoGenIPv6LinkLocal = true - } - - e := channel.New(10, 1280, test.linkAddr) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - - if test.shouldGen { - // Should have auto-generated an address and - // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) - } - } else { - // Should not have auto-generated an address. - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - } - }) - } -} - -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, - } - - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - linkLocalAddr := header.LinkLocalAddr(linkAddr1) - - // Wait for DAD to resolve. - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != linkLocalAddr { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) - } -} - -// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected -// when an address's kind gets "promoted" to permanent from permanentExpired. -func TestNewPEBOnPromotionToPermanent(t *testing.T) { - pebs := []stack.PrimaryEndpointBehavior{ - stack.NeverPrimaryEndpoint, - stack.CanBePrimaryEndpoint, - stack.FirstPrimaryEndpoint, - } - - for _, pi := range pebs { - for _, ps := range pebs { - t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - // Add a permanent address with initial - // PrimaryEndpointBehavior (peb), pi. If pi is - // NeverPrimaryEndpoint, the address should not - // be returned by a call to GetMainNICAddress; - // else, it should. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - addr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) - } - if pi == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - - } - } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatalf("NewSubnet failed:", err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Take a route through the address so its ref - // count gets incremented and does not actually - // get deleted when RemoveAddress is called - // below. This is because we want to test that a - // new peb is respected when an address gets - // "promoted" to permanent from a - // permanentExpired kind. - r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - defer r.Release() - if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) - } - - // - // At this point, the address should still be - // known by the NIC, but have its - // kind = permanentExpired. - // - - // Add some other address with peb set to - // FirstPrimaryEndpoint. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - - } - - // Add back the address we removed earlier and - // make sure the new peb was respected. - // (The address should just be promoted now). - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) - } - var primaryAddrs []tcpip.Address - for _, pa := range s.NICInfo()[1].ProtocolAddresses { - primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) - } - var expectedList []tcpip.Address - switch ps { - case stack.FirstPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x01", - "\x03", - } - case stack.CanBePrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - "\x01", - } - case stack.NeverPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - } - } - if !cmp.Equal(primaryAddrs, expectedList) { - t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) - } - - // Once we remove the other address, if the new - // peb, ps, was NeverPrimaryEndpoint, no address - // should be returned by a call to - // GetMainNICAddress; else, our original address - // should be returned. - if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) - } - addr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) - } - if ps == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - - } - } else { - if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) - } - } - }) - } - } -} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go deleted file mode 100644 index 3b28b06d0..000000000 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ /dev/null @@ -1,354 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "math" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testPort = 4096 -) - -type testContext struct { - t *testing.T - linkEPs map[string]*channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } -} - -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicID := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - linkEPs[linkEpName] = channelEP - - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress IPv4 failed: %v", err) - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress IPv6 failed: %v", err) - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEPs: linkEPs, - } -} - -type headers struct { - srcPort uint16 - dstPort uint16 -} - -func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -func TestTransportDemuxerRegister(t *testing.T) { - for _, test := range []struct { - name string - proto tcpip.NetworkProtocolNumber - want *tcpip.Error - }{ - {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, - {"success", ipv4.ProtocolNumber, nil}, - } { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) - } - }) - } -} - -// TestReuseBindToDevice injects varied packets on input devices and checks that -// the distribution of packets received matches expectations. -func TestDistribution(t *testing.T) { - type endpointSockopts struct { - reuse int - bindToDevice string - } - for _, test := range []struct { - name string - // endpoints will received the inject packets. - endpoints []endpointSockopts - // wantedDistribution is the wanted ratio of packets received on each - // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 - }{ - { - "BindPortReuse", - // 5 endpoints that all have reuse set. - []endpointSockopts{ - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - }, - map[string][]float64{ - // Injected packets on dev0 get distributed evenly. - "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, - }, - }, - { - "BindToDevice", - // 3 endpoints with various bindings. - []endpointSockopts{ - endpointSockopts{0, "dev0"}, - endpointSockopts{0, "dev1"}, - endpointSockopts{0, "dev2"}, - }, - map[string][]float64{ - // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": []float64{1, 0, 0}, - // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": []float64{0, 1, 0}, - // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": []float64{0, 0, 1}, - }, - }, - { - "ReuseAndBindToDevice", - // 6 endpoints with various bindings. - []endpointSockopts{ - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, ""}, - }, - map[string][]float64{ - // Injected packets on dev0 get distributed among endpoints bound to - // dev0. - "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, - // Injected packets on dev1 get distributed among endpoints bound to - // dev1 or unbound. - "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, - // Injected packets on dev999 go only to the unbound. - "dev999": []float64{0, 0, 0, 0, 0, 1}, - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string - for d := range test.wantedDistributions { - devices = append(devices, d) - } - c := newDualTestContextMultiNic(t, defaultMTU, devices) - defer c.cleanup() - - c.createV6Endpoint(false) - - eps := make(map[tcpip.Endpoint]int) - - pollChannel := make(chan tcpip.Endpoint) - for i, endpoint := range test.endpoints { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - eps[ep] = i - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(ep) - - defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) - } - bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) - } - if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) - } - } - - npackets := 100000 - nports := 10000 - if got, want := len(test.endpoints), len(wantedDistribution); got != want { - t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) - } - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.sendV6Packet(payload, - &headers{ - srcPort: testPort + port, - dstPort: stackPort}, - device) - - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled by the same - // socket. - if want, got := ports[port], ep; want != got { - t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) - } - } - } - - // Check that a packet distribution is as expected. - for ep, i := range eps { - wantedRatio := wantedDistribution[i] - wantedRecv := wantedRatio * float64(npackets) - actualRecv := stats[ep] - actualRatio := float64(stats[ep]) / float64(npackets) - // The deviation is less than 10%. - if math.Abs(actualRatio-wantedRatio) > 0.05 { - t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) - } - } - }) - } - }) - } -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go deleted file mode 100644 index 748ce4ea5..000000000 --- a/pkg/tcpip/stack/transport_test.go +++ /dev/null @@ -1,629 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 -) - -// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't -// use it. -type fakeTransportEndpoint struct { - stack.TransportEndpointInfo - stack *stack.Stack - proto *fakeTransportProtocol - peerAddr tcpip.Address - route stack.Route - uniqueID uint64 - - // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint -} - -func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { - return &f.TransportEndpointInfo -} - -func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { - return nil -} - -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} -} - -func (f *fakeTransportEndpoint) Close() { - f.route.Release() -} - -func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return mask -} - -func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return buffer.View{}, tcpip.ControlMessages{}, nil -} - -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - if len(f.route.RemoteAddress) == 0 { - return 0, nil, tcpip.ErrNoRoute - } - - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.FullPayload() - if err != nil { - return 0, nil, err - } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { - return 0, nil, err - } - - return int64(len(v)), nil, nil -} - -func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - -// SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { - return -1, tcpip.ErrUnknownProtocolOption -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - } - return tcpip.ErrInvalidEndpointState -} - -// Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported -} - -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - f.peerAddr = addr.Addr - - // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return tcpip.ErrNoRoute - } - defer r.Release() - - // Try to register so that we can start receiving packets. - f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) - if err != nil { - return err - } - - f.route = r.Clone() - - return nil -} - -func (f *fakeTransportEndpoint) UniqueID() uint64 { - return f.uniqueID -} - -func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Reset() { -} - -func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { - return nil -} - -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - if len(f.acceptQueue) == 0 { - return nil, nil, nil - } - a := f.acceptQueue[0] - f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil -} - -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( - a.NIC, - []tcpip.NetworkProtocolNumber{fakeNetNumber}, - fakeTransNumber, - stack.TransportEndpointID{LocalAddress: a.Addr}, - f, - false, /* reuse */ - 0, /* bindtoDevice */ - ); err != nil { - return err - } - f.acceptQueue = []fakeTransportEndpoint{} - return nil -} - -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { - // Increment the number of received packets. - f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - stack: f.stack, - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) - } -} - -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { - // Increment the number of received control packets. - f.proto.controlCount++ -} - -func (f *fakeTransportEndpoint) State() uint32 { - return 0 -} - -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} - -func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { - return iptables.IPTables{}, nil -} - -func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} - -func (f *fakeTransportEndpoint) Wait() {} - -type fakeTransportGoodOption bool - -type fakeTransportBadOption bool - -type fakeTransportInvalidValueOption int - -type fakeTransportProtocolOptions struct { - good bool -} - -// fakeTransportProtocol is a transport-layer protocol descriptor. It -// aggregates the number of packets received via endpoints of this protocol. -type fakeTransportProtocol struct { - packetCount int - controlCount int - opts fakeTransportProtocolOptions -} - -func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { - return fakeTransNumber -} - -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil -} - -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrUnknownProtocol -} - -func (*fakeTransportProtocol) MinimumPacketSize() int { - return fakeTransHeaderLen -} - -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { - return 0, 0, nil -} - -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { - return true -} - -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) - return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) - return nil - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func fakeTransFactory() stack.TransportProtocol { - return &fakeTransportProtocol{} -} - -func TestTransportReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the packet. - buf := buffer.NewView(30) - - // Make sure packet with wrong protocol is not delivered. - buf[0] = 1 - buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[0] = 1 - buf[1] = 3 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet is delivered. - buf[0] = 1 - buf[1] = 2 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.packetCount != 1 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) - } -} - -func TestTransportControlReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the control packet. - buf := buffer.NewView(2*fakeNetHeaderLen + 30) - - // Outer packet contains the control protocol number. - buf[0] = 1 - buf[1] = 0xfe - buf[2] = uint8(fakeControlProtocol) - - // Make sure packet with wrong protocol is not delivered. - buf[fakeNetHeaderLen+0] = 0 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[fakeNetHeaderLen+0] = 3 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet is delivered. - buf[fakeNetHeaderLen+0] = 2 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - if fakeTrans.controlCount != 1 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) - } -} - -func TestTransportSend(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Create endpoint and bind it. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Create buffer that will hold the payload. - view := buffer.NewView(30) - _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("write failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - if fakeNet.sendPacketCount[2] != 1 { - t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) - } -} - -func TestTransportOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } - } -} - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, - }) - s.SetForwarding(true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: req.ToVectorisedView(), - }) - - aep, _, err := ep.Accept() - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - var p channel.PacketInfo - select { - case p = <-ep2.C: - default: - t.Fatal("Response packet not forwarded") - } - - if dst := p.Pkt.Header.View()[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := p.Pkt.Header.View()[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go new file mode 100755 index 000000000..7083b7964 --- /dev/null +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -0,0 +1,75 @@ +// automatically generated by stateify. + +package tcpip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *PacketBuffer) save(m state.Map) { + x.beforeSave() + m.Save("Data", &x.Data) + m.Save("DataOffset", &x.DataOffset) + m.Save("DataSize", &x.DataSize) + m.Save("Header", &x.Header) + m.Save("LinkHeader", &x.LinkHeader) + m.Save("NetworkHeader", &x.NetworkHeader) + m.Save("TransportHeader", &x.TransportHeader) +} + +func (x *PacketBuffer) afterLoad() {} +func (x *PacketBuffer) load(m state.Map) { + m.Load("Data", &x.Data) + m.Load("DataOffset", &x.DataOffset) + m.Load("DataSize", &x.DataSize) + m.Load("Header", &x.Header) + m.Load("LinkHeader", &x.LinkHeader) + m.Load("NetworkHeader", &x.NetworkHeader) + m.Load("TransportHeader", &x.TransportHeader) +} + +func (x *FullAddress) beforeSave() {} +func (x *FullAddress) save(m state.Map) { + x.beforeSave() + m.Save("NIC", &x.NIC) + m.Save("Addr", &x.Addr) + m.Save("Port", &x.Port) +} + +func (x *FullAddress) afterLoad() {} +func (x *FullAddress) load(m state.Map) { + m.Load("NIC", &x.NIC) + m.Load("Addr", &x.Addr) + m.Load("Port", &x.Port) +} + +func (x *ControlMessages) beforeSave() {} +func (x *ControlMessages) save(m state.Map) { + x.beforeSave() + m.Save("HasTimestamp", &x.HasTimestamp) + m.Save("Timestamp", &x.Timestamp) + m.Save("HasInq", &x.HasInq) + m.Save("Inq", &x.Inq) + m.Save("HasTOS", &x.HasTOS) + m.Save("TOS", &x.TOS) + m.Save("HasTClass", &x.HasTClass) + m.Save("TClass", &x.TClass) +} + +func (x *ControlMessages) afterLoad() {} +func (x *ControlMessages) load(m state.Map) { + m.Load("HasTimestamp", &x.HasTimestamp) + m.Load("Timestamp", &x.Timestamp) + m.Load("HasInq", &x.HasInq) + m.Load("Inq", &x.Inq) + m.Load("HasTOS", &x.HasTOS) + m.Load("TOS", &x.TOS) + m.Load("HasTClass", &x.HasTClass) + m.Load("TClass", &x.TClass) +} + +func init() { + state.Register("tcpip.PacketBuffer", (*PacketBuffer)(nil), state.Fns{Save: (*PacketBuffer).save, Load: (*PacketBuffer).load}) + state.Register("tcpip.FullAddress", (*FullAddress)(nil), state.Fns{Save: (*FullAddress).save, Load: (*FullAddress).load}) + state.Register("tcpip.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load}) +} diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go deleted file mode 100644 index 8c0aacffa..000000000 --- a/pkg/tcpip/tcpip_test.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import ( - "fmt" - "net" - "strings" - "testing" -) - -func TestSubnetContains(t *testing.T) { - tests := []struct { - s Address - m AddressMask - a Address - want bool - }{ - {"\xa0", "\xf0", "\x90", false}, - {"\xa0", "\xf0", "\xa0", true}, - {"\xa0", "\xf0", "\xa5", true}, - {"\xa0", "\xf0", "\xaf", true}, - {"\xa0", "\xf0", "\xb0", false}, - {"\xa0", "\xf0", "", false}, - {"\xa0", "\xf0", "\xa0\x00", false}, - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - s, err := NewSubnet(tt.s, tt.m) - if err != nil { - t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err) - continue - } - if got := s.Contains(tt.a); got != tt.want { - t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want) - } - } -} - -func TestSubnetBits(t *testing.T) { - tests := []struct { - a AddressMask - want1 int - want0 int - }{ - {"\x00", 0, 8}, - {"\x00\x00", 0, 16}, - {"\x36", 0, 8}, - {"\x5c", 0, 8}, - {"\x5c\x5c", 0, 16}, - {"\x5c\x36", 0, 16}, - {"\x36\x5c", 0, 16}, - {"\x36\x36", 0, 16}, - {"\xff", 8, 0}, - {"\xff\xff", 16, 0}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got1, got0 := s.Bits() - if got1 != tt.want1 || got0 != tt.want0 { - t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0) - } - } -} - -func TestSubnetPrefix(t *testing.T) { - tests := []struct { - a AddressMask - want int - }{ - {"\x00", 0}, - {"\x00\x00", 0}, - {"\x36", 0}, - {"\x86", 1}, - {"\xc5", 2}, - {"\xff\x00", 8}, - {"\xff\x36", 8}, - {"\xff\x8c", 9}, - {"\xff\xc8", 10}, - {"\xff", 8}, - {"\xff\xff", 16}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got := s.Prefix() - if got != tt.want { - t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want) - } - } -} - -func TestSubnetCreation(t *testing.T) { - tests := []struct { - a Address - m AddressMask - want error - }{ - {"\xa0", "\xf0", nil}, - {"\xa0\xa0", "\xf0", errSubnetLengthMismatch}, - {"\xaa", "\xf0", errSubnetAddressMasked}, - {"", "", nil}, - } - for _, tt := range tests { - if _, err := NewSubnet(tt.a, tt.m); err != tt.want { - t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want) - } - } -} - -func TestAddressString(t *testing.T) { - for _, want := range []string{ - // Taken from stdlib. - "2001:db8::123:12:1", - "2001:db8::1", - "2001:db8:0:1:0:1:0:1", - "2001:db8:1:0:1:0:1:0", - "2001::1:0:0:1", - "2001:db8:0:0:1::", - "2001:db8::1:0:0:1", - "2001:db8::a:b:c:d", - - // Leading zeros. - "::1", - // Trailing zeros. - "8::", - // No zeros. - "1:1:1:1:1:1:1:1", - // Longer sequence is after other zeros, but not at the end. - "1:0:0:1::1", - // Longer sequence is at the beginning, shorter sequence is at - // the end. - "::1:1:1:0:0", - // Longer sequence is not at the beginning, shorter sequence is - // at the end. - "1::1:1:0:0", - // Longer sequence is at the beginning, shorter sequence is not - // at the end. - "::1:1:0:0:1", - // Neither sequence is at an end, longer is after shorter. - "1:0:0:1::1", - // Shorter sequence is at the beginning, longer sequence is not - // at the end. - "0:0:1:1::1", - // Shorter sequence is at the beginning, longer sequence is at - // the end. - "0:0:1:1:1::", - // Short sequences at both ends, longer one in the middle. - "0:1:1::1:1:0", - // Short sequences at both ends, longer one in the middle. - "0:1::1:0:0", - // Short sequences at both ends, longer one in the middle. - "0:0:1::1:0", - // Longer sequence surrounded by shorter sequences, but none at - // the end. - "1:0:1::1:0:1", - } { - addr := Address(net.ParseIP(want)) - if got := addr.String(); got != want { - t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want) - } - } -} - -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - -func TestAddressWithPrefixSubnet(t *testing.T) { - tests := []struct { - addr Address - prefixLen int - subnetAddr Address - subnetMask AddressMask - }{ - {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, - {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, - {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - } - for _, tt := range tests { - ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} - gotSubnet := ap.Subnet() - wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) - if err != nil { - t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) - continue - } - if gotSubnet != wantSubnet { - t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) - } - } -} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD deleted file mode 100644 index d8c5b5058..000000000 --- a/pkg/tcpip/transport/icmp/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_template_instance( - name = "icmp_packet_list", - out = "icmp_packet_list.go", - package = "icmp", - prefix = "icmpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*icmpPacket", - "Linker": "*icmpPacket", - }, -) - -go_library( - name = "icmp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "icmp_packet_list.go", - "protocol.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go new file mode 100755 index 000000000..1b35e5b4a --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -0,0 +1,173 @@ +package icmp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type icmpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type icmpPacketList struct { + head *icmpPacket + tail *icmpPacket +} + +// Reset resets list l to the empty state. +func (l *icmpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *icmpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *icmpPacketList) Front() *icmpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *icmpPacketList) Back() *icmpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *icmpPacketList) PushFront(e *icmpPacket) { + icmpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *icmpPacketList) PushBack(e *icmpPacket) { + icmpPacketElementMapper{}.linkerFor(e).SetNext(nil) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *icmpPacketList) PushBackList(m *icmpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) { + a := icmpPacketElementMapper{}.linkerFor(b).Next() + icmpPacketElementMapper{}.linkerFor(e).SetNext(a) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(b) + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) { + b := icmpPacketElementMapper{}.linkerFor(a).Prev() + icmpPacketElementMapper{}.linkerFor(e).SetNext(a) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(b) + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *icmpPacketList) Remove(e *icmpPacket) { + prev := icmpPacketElementMapper{}.linkerFor(e).Prev() + next := icmpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + icmpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type icmpPacketEntry struct { + next *icmpPacket + prev *icmpPacket +} + +// Next returns the entry that follows e in the list. +func (e *icmpPacketEntry) Next() *icmpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *icmpPacketEntry) Prev() *icmpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *icmpPacketEntry) SetNext(elem *icmpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go new file mode 100755 index 000000000..bf2d76839 --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -0,0 +1,92 @@ +// automatically generated by stateify. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *icmpPacket) beforeSave() {} +func (x *icmpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("icmpPacketEntry", &x.icmpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) +} + +func (x *icmpPacket) afterLoad() {} +func (x *icmpPacket) load(m state.Map) { + m.Load("icmpPacketEntry", &x.icmpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("uniqueID", &x.uniqueID) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("state", &x.state) + m.Save("ttl", &x.ttl) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("uniqueID", &x.uniqueID) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("state", &x.state) + m.Load("ttl", &x.ttl) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *icmpPacketList) beforeSave() {} +func (x *icmpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *icmpPacketList) afterLoad() {} +func (x *icmpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *icmpPacketEntry) beforeSave() {} +func (x *icmpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *icmpPacketEntry) afterLoad() {} +func (x *icmpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load}) + state.Register("icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load}) + state.Register("icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD deleted file mode 100644 index 44b58ff6b..000000000 --- a/pkg/tcpip/transport/packet/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_template_instance( - name = "packet_list", - out = "packet_list.go", - package = "packet", - prefix = "packet", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*packet", - "Linker": "*packet", - }, -) - -go_library( - name = "packet", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "packet_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0010b5e5f..0010b5e5f 100644..100755 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..9b88f17e4 100644..100755 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go new file mode 100755 index 000000000..0da0dfcb6 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_list.go @@ -0,0 +1,173 @@ +package packet + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type packetElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (packetElementMapper) linkerFor(elem *packet) *packet { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type packetList struct { + head *packet + tail *packet +} + +// Reset resets list l to the empty state. +func (l *packetList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *packetList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *packetList) Front() *packet { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *packetList) Back() *packet { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *packetList) PushFront(e *packet) { + packetElementMapper{}.linkerFor(e).SetNext(l.head) + packetElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + packetElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *packetList) PushBack(e *packet) { + packetElementMapper{}.linkerFor(e).SetNext(nil) + packetElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *packetList) PushBackList(m *packetList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(m.head) + packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *packetList) InsertAfter(b, e *packet) { + a := packetElementMapper{}.linkerFor(b).Next() + packetElementMapper{}.linkerFor(e).SetNext(a) + packetElementMapper{}.linkerFor(e).SetPrev(b) + packetElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + packetElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *packetList) InsertBefore(a, e *packet) { + b := packetElementMapper{}.linkerFor(a).Prev() + packetElementMapper{}.linkerFor(e).SetNext(a) + packetElementMapper{}.linkerFor(e).SetPrev(b) + packetElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + packetElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *packetList) Remove(e *packet) { + prev := packetElementMapper{}.linkerFor(e).Prev() + next := packetElementMapper{}.linkerFor(e).Next() + + if prev != nil { + packetElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + packetElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type packetEntry struct { + next *packet + prev *packet +} + +// Next returns the entry that follows e in the list. +func (e *packetEntry) Next() *packet { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *packetEntry) Prev() *packet { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *packetEntry) SetNext(elem *packet) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *packetEntry) SetPrev(elem *packet) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go new file mode 100755 index 000000000..7df5ec74c --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -0,0 +1,88 @@ +// automatically generated by stateify. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *packet) beforeSave() {} +func (x *packet) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("packetEntry", &x.packetEntry) + m.Save("timestampNS", &x.timestampNS) + m.Save("senderAddr", &x.senderAddr) +} + +func (x *packet) afterLoad() {} +func (x *packet) load(m state.Map) { + m.Load("packetEntry", &x.packetEntry) + m.Load("timestampNS", &x.timestampNS) + m.Load("senderAddr", &x.senderAddr) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("netProto", &x.netProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("cooked", &x.cooked) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("closed", &x.closed) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("netProto", &x.netProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("cooked", &x.cooked) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("closed", &x.closed) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *packetList) beforeSave() {} +func (x *packetList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *packetList) afterLoad() {} +func (x *packetList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *packetEntry) beforeSave() {} +func (x *packetEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *packetEntry) afterLoad() {} +func (x *packetEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("packet.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load}) + state.Register("packet.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("packet.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load}) + state.Register("packet.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load}) +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD deleted file mode 100644 index 00991ac8e..000000000 --- a/pkg/tcpip/transport/raw/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_template_instance( - name = "raw_packet_list", - out = "raw_packet_list.go", - package = "raw", - prefix = "rawPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*rawPacket", - "Linker": "*rawPacket", - }, -) - -go_library( - name = "raw", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "protocol.go", - "raw_packet_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/packet", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go new file mode 100755 index 000000000..12edb4334 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_packet_list.go @@ -0,0 +1,173 @@ +package raw + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type rawPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type rawPacketList struct { + head *rawPacket + tail *rawPacket +} + +// Reset resets list l to the empty state. +func (l *rawPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *rawPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *rawPacketList) Front() *rawPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *rawPacketList) Back() *rawPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *rawPacketList) PushFront(e *rawPacket) { + rawPacketElementMapper{}.linkerFor(e).SetNext(l.head) + rawPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *rawPacketList) PushBack(e *rawPacket) { + rawPacketElementMapper{}.linkerFor(e).SetNext(nil) + rawPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *rawPacketList) PushBackList(m *rawPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *rawPacketList) InsertAfter(b, e *rawPacket) { + a := rawPacketElementMapper{}.linkerFor(b).Next() + rawPacketElementMapper{}.linkerFor(e).SetNext(a) + rawPacketElementMapper{}.linkerFor(e).SetPrev(b) + rawPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + rawPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *rawPacketList) InsertBefore(a, e *rawPacket) { + b := rawPacketElementMapper{}.linkerFor(a).Prev() + rawPacketElementMapper{}.linkerFor(e).SetNext(a) + rawPacketElementMapper{}.linkerFor(e).SetPrev(b) + rawPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + rawPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *rawPacketList) Remove(e *rawPacket) { + prev := rawPacketElementMapper{}.linkerFor(e).Prev() + next := rawPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + rawPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + rawPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type rawPacketEntry struct { + next *rawPacket + prev *rawPacket +} + +// Next returns the entry that follows e in the list. +func (e *rawPacketEntry) Next() *rawPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *rawPacketEntry) Prev() *rawPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *rawPacketEntry) SetNext(elem *rawPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *rawPacketEntry) SetPrev(elem *rawPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go new file mode 100755 index 000000000..3753ccc37 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -0,0 +1,90 @@ +// automatically generated by stateify. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *rawPacket) beforeSave() {} +func (x *rawPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("rawPacketEntry", &x.rawPacketEntry) + m.Save("timestampNS", &x.timestampNS) + m.Save("senderAddr", &x.senderAddr) +} + +func (x *rawPacket) afterLoad() {} +func (x *rawPacket) load(m state.Map) { + m.Load("rawPacketEntry", &x.rawPacketEntry) + m.Load("timestampNS", &x.timestampNS) + m.Load("senderAddr", &x.senderAddr) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("associated", &x.associated) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("closed", &x.closed) + m.Save("connected", &x.connected) + m.Save("bound", &x.bound) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("associated", &x.associated) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("closed", &x.closed) + m.Load("connected", &x.connected) + m.Load("bound", &x.bound) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *rawPacketList) beforeSave() {} +func (x *rawPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *rawPacketList) afterLoad() {} +func (x *rawPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *rawPacketEntry) beforeSave() {} +func (x *rawPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *rawPacketEntry) afterLoad() {} +func (x *rawPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("raw.rawPacket", (*rawPacket)(nil), state.Fns{Save: (*rawPacket).save, Load: (*rawPacket).load}) + state.Register("raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("raw.rawPacketList", (*rawPacketList)(nil), state.Fns{Save: (*rawPacketList).save, Load: (*rawPacketList).load}) + state.Register("raw.rawPacketEntry", (*rawPacketEntry)(nil), state.Fns{Save: (*rawPacketEntry).save, Load: (*rawPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD deleted file mode 100644 index 455a1c098..000000000 --- a/pkg/tcpip/transport/tcp/BUILD +++ /dev/null @@ -1,94 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_template_instance( - name = "tcp_segment_list", - out = "tcp_segment_list.go", - package = "tcp", - prefix = "segment", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*segment", - "Linker": "*segment", - }, -) - -go_library( - name = "tcp", - srcs = [ - "accept.go", - "connect.go", - "cubic.go", - "cubic_state.go", - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "rcv.go", - "reno.go", - "sack.go", - "sack_scoreboard.go", - "segment.go", - "segment_heap.go", - "segment_queue.go", - "segment_state.go", - "snd.go", - "snd_state.go", - "tcp_segment_list.go", - "timer.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tmutex", - "//pkg/waiter", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "tcp_test", - size = "medium", - srcs = [ - "dual_stack_test.go", - "sack_scoreboard_test.go", - "tcp_noracedetector_test.go", - "tcp_sack_test.go", - "tcp_test.go", - "tcp_timestamp_test.go", - ], - # FIXME(b/68809571) - tags = ["flaky"], - deps = [ - ":tcp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp/testing/context", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go deleted file mode 100644 index dfaa4a559..000000000 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ /dev/null @@ -1,654 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestV4MappedConnectOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Start connection attempt, it must fail. - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrNoRoute { - t.Fatalf("Unexpected return value from Connect: %v", err) - } -} - -func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv4(t, b, synCheckers...) - - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - )) - checker.IPv4(t, c.GetPacket(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV4MappedConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt to IPv6 address. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetV6Packet() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv6(t, b, synCheckers...) - - tcp := header.TCP(header.IPv6(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - )) - checker.IPv6(t, c.GetV6Packet(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV6Connect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to local address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV4RefuseOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), - ), - ) -} - -func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), - ), - ) -} - -func testV4Accept(t *testing.T, c *context.Context) { - c.SetGSOEnabled(true) - defer c.SetGSOEnabled(false) - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv4(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Make sure we get the same error when calling the original ep and the - // new one. This validates that v4-mapped endpoints are still able to - // query the V6Only flag, whereas pure v4 endpoints are not. - var v tcpip.V6OnlyOption - expected := c.EP.GetSockOpt(&v) - if err := nep.GetSockOpt(&v); err != expected { - t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestAddr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) - } - - data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestV4AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV6AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv6(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Make sure we can still query the v6 only status of the new endpoint, - // that is, that it is in fact a v6 socket. - var v tcpip.V6OnlyOption - if err := nep.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) - } -} - -func TestV4AcceptOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func testV4ListenClose(t *testing.T, c *context.Context) { - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 - const n = uint16(32) - - // Start listening. - if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - irs := seqnum.Value(789) - for i := uint16(0); i < n; i++ { - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + i, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Each of these ACK's will cause a syn-cookie based connection to be - // accepted and delivered to the listening endpoint. - for i := uint16(0); i < n; i++ { - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - } - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - nep.Close() - c.EP.Close() -} - -func TestV4ListenCloseOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4ListenClose(t, c) -} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go deleted file mode 100644 index b4e5ba0df..000000000 --- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -const smss = 1500 - -func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard { - s := tcp.NewSACKScoreboard(smss, iss) - for _, blk := range blocks { - s.Insert(blk) - } - return s -} - -func TestSACKScoreboardIsSACKED(t *testing.T) { - type blockTest struct { - block header.SACKBlock - sacked bool - } - testCases := []struct { - comment string - scoreboardBlocks []header.SACKBlock - blockTests []blockTest - iss seqnum.Value - }{ - { - "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", - []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, - []blockTest{ - {header.SACKBlock{15, 21}, true}, - {header.SACKBlock{200, 201}, false}, - {header.SACKBlock{50, 51}, false}, - {header.SACKBlock{53, 120}, true}, - }, - 0, - }, - { - "Test disjoint SACKBlocks", - []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, - []blockTest{ - {header.SACKBlock{2288624809, 2288810057}, true}, - {header.SACKBlock{2288811477, 2288838565}, true}, - {header.SACKBlock{2288810057, 2288811477}, false}, - }, - 2288624809, - }, - { - "Test sequence number wrap around", - []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, - []blockTest{ - {header.SACKBlock{4294254144, 4294254145}, true}, - {header.SACKBlock{4294254143, 4294254144}, false}, - {header.SACKBlock{4294254144, 1}, true}, - {header.SACKBlock{225652, 5350509}, false}, - {header.SACKBlock{5340409, 5350509}, true}, - {header.SACKBlock{5350509, 5350609}, false}, - }, - 4294254144, - }, - { - "Test disjoint SACKBlocks out of order", - []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, - []blockTest{ - {header.SACKBlock{827426028, 827428867}, true}, - {header.SACKBlock{827450168, 827450275}, false}, - }, - 827426000, - }, - } - for _, tc := range testCases { - sb := initScoreboard(tc.scoreboardBlocks, tc.iss) - for _, blkTest := range tc.blockTests { - if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { - t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) - } - } - } -} - -func TestSACKScoreboardIsRangeLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - block header.SACKBlock - lost bool - }{ - // Block not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number covered by this block. - {block: header.SACKBlock{0, 1}, lost: true}, - - // These blocks have all been SACKed and should not be - // considered lost. - {block: header.SACKBlock{1, 2}, lost: false}, - {block: header.SACKBlock{25, 26}, lost: false}, - {block: header.SACKBlock{1, 45}, lost: false}, - - // Same as the first case above. - {block: header.SACKBlock{50, 51}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{119, 120}, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {block: header.SACKBlock{120, 121}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{125, 126}, lost: false}, - - // This block has not been SACKed and there are nDupAckThreshold - // number of SACKed blocks after it. - {block: header.SACKBlock{141, 145}, lost: true}, - - // This block has not been SACKed and there are less than - // nDupAckThreshold SACKed sequences after it. - {block: header.SACKBlock{151, 152}, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { - t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) - } - } -} - -func TestSACKScoreboardIsLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - seq seqnum.Value - lost bool - }{ - // Sequence number not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number. - {seq: 0, lost: true}, - - // These sequence numbers have all been SACKed and should not be - // considered lost. - {seq: 1, lost: false}, - {seq: 25, lost: false}, - {seq: 45, lost: false}, - - // Same as first case above. - {seq: 50, lost: true}, - - // This block has been SACKed and should not be considered lost. - {seq: 119, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {seq: 120, lost: true}, - - // This sequence number has been SACKed and should not be - // considered lost. - {seq: 125, lost: false}, - - // This sequence number has not been SACKed and there are - // nDupAckThreshold number of SACKed blocks after it. - {seq: 141, lost: true}, - - // This sequence number has not been SACKed and there are less - // than nDupAckThreshold SACKed sequences after it. - {seq: 151, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.seq); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) - } - } -} - -func TestSACKScoreboardDelete(t *testing.T) { - blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} - s := initScoreboard(blocks, 4294254143) - s.Delete(5340408) - if s.Empty() { - t.Fatalf("s.Empty() = true, want false") - } - if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } - s.Delete(5340410) - if s.Empty() { - t.Fatal("s.Empty() = true, want false") - } - newSB := header.SACKBlock{5340410, 5350509} - if !s.IsSACKED(newSB) { - t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) - } - s.Delete(5350509) - lastOctet := header.SACKBlock{5350508, 5350509} - if s.IsSACKED(lastOctet) { - t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) - } - - s.Delete(5350510) - if !s.Empty() { - t.Fatal("s.Empty() = false, want true") - } - if got, want := s.Sacked(), seqnum.Size(0); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go deleted file mode 100644 index 782d7b42c..000000000 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ /dev/null @@ -1,527 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// These tests are flaky when run under the go race detector due to some -// iterations taking long enough that the retransmit timer can kick in causing -// the congestion window measurements to fail due to extra packets etc. -// -// +build !race - -package tcp_test - -import ( - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" -) - -func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) - } - - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 - } - expected++ - } -} - -func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // Double the number of expected packets for the next iteration. - expected *= 2 - } -} - -func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } -} - -// cubicCwnd returns an estimate of a cubic window given the -// originalCwnd, wMax, last congestion event time and sRTT. -func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { - cwnd := float64(origCwnd) - // We wait 50ms between each iteration so sRTT as computed by cubic - // should be close to 50ms. - elapsed := (time.Since(congEventTime) + sRTT).Seconds() - k := math.Cbrt(float64(wMax) * 0.3 / 0.7) - wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) - cwnd += (wtRTT - cwnd) / cwnd - return int(cwnd) -} - -func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - enableCUBIC(t, c) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. - // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) - - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break - } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - } -} - -func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - half := data[:len(data)/2] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - half = data[len(data)/2:] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) - } - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. - for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } - - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) -} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go deleted file mode 100644 index afea124ec..000000000 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ /dev/null @@ -1,572 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "fmt" - "log" - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" -) - -// createConnectedWithSACKPermittedOption creates and connects c.ep with the -// SACKPermitted option enabled if the stack in the context has the SACK support -// enabled. -func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) -} - -// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS -// option enabled if the stack in the context has SACK and TS enabled. -func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) -} - -func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { - t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) - } -} - -// TestSackPermittedConnect establishes a connection with the SACK option -// enabled. -func TestSackPermittedConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - rep := createConnectedWithSACKPermittedOption(c) - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // Restore the saved sequence number so that the - // VerifyXXX calls use the right sequence number for - // checking ACK numbers. - rep.NextSeqNum = savedSeqNum - if sackEnabled { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackDisabledConnect establishes a connection with the SACK option -// disabled and verifies that no SACKs are sent for out of order segments. -func TestSackDisabledConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) - - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older sequence number and - // no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackPermittedAccept accepts and establishes a connection with the -// SACKPermitted option enabled if the connection request specifies the -// SACKPermitted option. In case of SYN cookies SACK should be disabled as we -// don't encode the SACK information in the cookie. -func TestSackPermittedAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - sackPermitted bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number. - rep.NextSeqNum = savedSeqNum - if sackEnabled && tc.sackPermitted { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -// TestSackDisabledAccept accepts and establishes a connection with -// the SACKPermitted option disabled and verifies that no SACKs are -// sent for out of order packets. -func TestSackDisabledAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number and no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -func TestUpdateSACKBlocks(t *testing.T) { - testCases := []struct { - segStart seqnum.Value - segEnd seqnum.Value - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - updated []header.SACKBlock - }{ - // Trivial cases where current SACK block list is empty and we - // have an out of order delivery. - {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, - {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, - {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, - - // Cases where current SACK block list is not empty and we have - // an out of order delivery. Tests that the updated SACK block - // list has the first block as the one that contains the new - // SACK block representing the segment that was just delivered. - {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, - - // Ensure that we only retain header.MaxSACKBlocks and drop the - // oldest one if adding a new block exceeds - // header.MaxSACKBlocks. - {24, 30, 9, - []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, - []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, - - // Cases where segment extends an existing SACK block. - {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, - {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, - - // Cases where segment contains rcvNxt. - {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, - } - - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { - t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) - } - - } -} - -func TestTrimSackBlockList(t *testing.T) { - testCases := []struct { - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - trimmed []header.SACKBlock - }{ - // Simple cases where we trim whole entries. - {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, - {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, - {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, - {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - // Cases where we need to update a block. - {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, - {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, - {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, - {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - } - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.TrimSACKBlockList(&sack, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { - t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) - } - } -} - -func TestSACKRecovery(t *testing.T) { - const maxPayload = 10 - // See: tcp.makeOptions for why tsOptionSize is set to 12 here. - const tsOptionSize = 12 - // Enabling SACK means the payload size is reduced to account - // for the extra space required for the TCP options. - // - // We increase the MTU by 40 bytes to account for SACK and Timestamp - // options. - const maxTCPOptionSize = 40 - - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) - defer c.Cleanup() - - c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { - // We use log.Printf instead of t.Logf here because this probe - // can fire even when the test function has finished. This is - // because closing the endpoint in cleanup() does not mean the - // actual worker loop terminates immediately as it still has to - // do a full TCP shutdown. But this test can finish running - // before the shutdown is done. Using t.Logf in such a case - // causes the test to panic due to logging after test finished. - log.Printf("state: %+v\n", s) - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end := start.Add(10) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) - } - } - - // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause - // window inflation and sending of packets is completely handled by the - // SACK Recovery algorithm. We should see no packets being released, as - // the cwnd at this point after entering recovery should be half of the - // outstanding number of packets in flight. - for i := 0; i < 7; i++ { - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. This along with the 10 sacked - // segments above should reduce the outstanding below the current - // congestion window allowing the sender to transmit data. - rtxOffset = bytesRead - expected*maxPayload/2 - - // Now send a partial ACK w/ a SACK block that indicates that the next 3 - // segments are lost and we have received 6 segments after the lost - // segments. This should cause the sender to immediately transmit all 3 - // segments in response to this ACK unlike in FastRecovery where only 1 - // segment is retransmitted per ACK. - start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end = start.Add(60) - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - - // At this point, we acked expected/2 packets and we SACKED 6 packets and - // 3 segments were considered lost due to the SACK block we sent. - // - // So total packets outstanding can be calculated as follows after 7 - // iterations of slow start -> 10/20/40/80/160/320/640. So expected - // should be 640 at start, then we went to recover at which point the - // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the - // network). - // Outstanding at this point after acking half the window - // (320 packets) will be: - // outstanding = 640-320-6(due to SACK block)-3 = 311 - // - // The last 3 is due to the fact that the first 3 packets after - // rtxOffset will be considered lost due to the SACK blocks sent. - // Receive the retransmit due to partial ack. - - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - // Receive the 2 extra packets that should have been retransmitted as - // those should be considered lost and immediately retransmitted based - // on the SACK information in the previous ACK sent above. - for i := 0; i < 2; i++ { - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) - } - - // Now we should get 9 more new unsent packets as the cwnd is 323 and - // outstanding is 311. - for i := 0; i < 9; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) - } - - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 9 - // packets outstanding. - // - // Now in the first iteration since there are 9 packets outstanding. - // We would expect to get expected/2 - 9 packets. But subsequent - // iterations will send us expected/2 + 1 (per iteration). - expected = expected/2 - 9 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 9 - } - expected++ - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go new file mode 100755 index 000000000..029f98a11 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,173 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *segmentList) Back() *segment { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *segmentList) PushFront(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(l.head) + segmentElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *segmentList) PushBack(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(nil) + segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *segmentList) InsertAfter(b, e *segment) { + a := segmentElementMapper{}.linkerFor(b).Next() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *segmentList) InsertBefore(a, e *segment) { + b := segmentElementMapper{}.linkerFor(a).Prev() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *segmentList) Remove(e *segment) { + prev := segmentElementMapper{}.linkerFor(e).Prev() + next := segmentElementMapper{}.linkerFor(e).Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go new file mode 100755 index 000000000..cc3f5c931 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -0,0 +1,493 @@ +// automatically generated by stateify. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *cubicState) beforeSave() {} +func (x *cubicState) save(m state.Map) { + x.beforeSave() + var t unixTime = x.saveT() + m.SaveValue("t", t) + m.Save("wLastMax", &x.wLastMax) + m.Save("wMax", &x.wMax) + m.Save("numCongestionEvents", &x.numCongestionEvents) + m.Save("c", &x.c) + m.Save("k", &x.k) + m.Save("beta", &x.beta) + m.Save("wC", &x.wC) + m.Save("wEst", &x.wEst) + m.Save("s", &x.s) +} + +func (x *cubicState) afterLoad() {} +func (x *cubicState) load(m state.Map) { + m.Load("wLastMax", &x.wLastMax) + m.Load("wMax", &x.wMax) + m.Load("numCongestionEvents", &x.numCongestionEvents) + m.Load("c", &x.c) + m.Load("k", &x.k) + m.Load("beta", &x.beta) + m.Load("wC", &x.wC) + m.Load("wEst", &x.wEst) + m.Load("s", &x.s) + m.LoadValue("t", new(unixTime), func(y interface{}) { x.loadT(y.(unixTime)) }) +} + +func (x *SACKInfo) beforeSave() {} +func (x *SACKInfo) save(m state.Map) { + x.beforeSave() + m.Save("Blocks", &x.Blocks) + m.Save("NumBlocks", &x.NumBlocks) +} + +func (x *SACKInfo) afterLoad() {} +func (x *SACKInfo) load(m state.Map) { + m.Load("Blocks", &x.Blocks) + m.Load("NumBlocks", &x.NumBlocks) +} + +func (x *rcvBufAutoTuneParams) beforeSave() {} +func (x *rcvBufAutoTuneParams) save(m state.Map) { + x.beforeSave() + var measureTime unixTime = x.saveMeasureTime() + m.SaveValue("measureTime", measureTime) + var rttMeasureTime unixTime = x.saveRttMeasureTime() + m.SaveValue("rttMeasureTime", rttMeasureTime) + m.Save("copied", &x.copied) + m.Save("prevCopied", &x.prevCopied) + m.Save("rtt", &x.rtt) + m.Save("rttMeasureSeqNumber", &x.rttMeasureSeqNumber) + m.Save("disabled", &x.disabled) +} + +func (x *rcvBufAutoTuneParams) afterLoad() {} +func (x *rcvBufAutoTuneParams) load(m state.Map) { + m.Load("copied", &x.copied) + m.Load("prevCopied", &x.prevCopied) + m.Load("rtt", &x.rtt) + m.Load("rttMeasureSeqNumber", &x.rttMeasureSeqNumber) + m.Load("disabled", &x.disabled) + m.LoadValue("measureTime", new(unixTime), func(y interface{}) { x.loadMeasureTime(y.(unixTime)) }) + m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) }) +} + +func (x *EndpointInfo) beforeSave() {} +func (x *EndpointInfo) save(m state.Map) { + x.beforeSave() + var HardError string = x.saveHardError() + m.SaveValue("HardError", HardError) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) +} + +func (x *EndpointInfo) afterLoad() {} +func (x *EndpointInfo) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.LoadValue("HardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var lastError string = x.saveLastError() + m.SaveValue("lastError", lastError) + var state EndpointState = x.saveState() + m.SaveValue("state", state) + var acceptedChan []*endpoint = x.saveAcceptedChan() + m.SaveValue("acceptedChan", acceptedChan) + m.Save("EndpointInfo", &x.EndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("uniqueID", &x.uniqueID) + m.Save("rcvList", &x.rcvList) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvBufUsed", &x.rcvBufUsed) + m.Save("rcvAutoParams", &x.rcvAutoParams) + m.Save("zeroWindow", &x.zeroWindow) + m.Save("isRegistered", &x.isRegistered) + m.Save("ttl", &x.ttl) + m.Save("v6only", &x.v6only) + m.Save("isConnectNotified", &x.isConnectNotified) + m.Save("broadcast", &x.broadcast) + m.Save("boundBindToDevice", &x.boundBindToDevice) + m.Save("boundPortFlags", &x.boundPortFlags) + m.Save("workerRunning", &x.workerRunning) + m.Save("workerCleanup", &x.workerCleanup) + m.Save("sendTSOk", &x.sendTSOk) + m.Save("recentTS", &x.recentTS) + m.Save("tsOffset", &x.tsOffset) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("sackPermitted", &x.sackPermitted) + m.Save("sack", &x.sack) + m.Save("reusePort", &x.reusePort) + m.Save("bindToDevice", &x.bindToDevice) + m.Save("delay", &x.delay) + m.Save("cork", &x.cork) + m.Save("scoreboard", &x.scoreboard) + m.Save("reuseAddr", &x.reuseAddr) + m.Save("slowAck", &x.slowAck) + m.Save("segmentQueue", &x.segmentQueue) + m.Save("synRcvdCount", &x.synRcvdCount) + m.Save("userMSS", &x.userMSS) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("sndBufUsed", &x.sndBufUsed) + m.Save("sndClosed", &x.sndClosed) + m.Save("sndBufInQueue", &x.sndBufInQueue) + m.Save("sndQueue", &x.sndQueue) + m.Save("cc", &x.cc) + m.Save("packetTooBigCount", &x.packetTooBigCount) + m.Save("sndMTU", &x.sndMTU) + m.Save("keepalive", &x.keepalive) + m.Save("rcv", &x.rcv) + m.Save("snd", &x.snd) + m.Save("connectingAddress", &x.connectingAddress) + m.Save("amss", &x.amss) + m.Save("sendTOS", &x.sendTOS) + m.Save("gso", &x.gso) + m.Save("tcpLingerTimeout", &x.tcpLingerTimeout) + m.Save("closed", &x.closed) +} + +func (x *endpoint) load(m state.Map) { + m.Load("EndpointInfo", &x.EndpointInfo) + m.LoadWait("waiterQueue", &x.waiterQueue) + m.Load("uniqueID", &x.uniqueID) + m.LoadWait("rcvList", &x.rcvList) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvBufUsed", &x.rcvBufUsed) + m.Load("rcvAutoParams", &x.rcvAutoParams) + m.Load("zeroWindow", &x.zeroWindow) + m.Load("isRegistered", &x.isRegistered) + m.Load("ttl", &x.ttl) + m.Load("v6only", &x.v6only) + m.Load("isConnectNotified", &x.isConnectNotified) + m.Load("broadcast", &x.broadcast) + m.Load("boundBindToDevice", &x.boundBindToDevice) + m.Load("boundPortFlags", &x.boundPortFlags) + m.Load("workerRunning", &x.workerRunning) + m.Load("workerCleanup", &x.workerCleanup) + m.Load("sendTSOk", &x.sendTSOk) + m.Load("recentTS", &x.recentTS) + m.Load("tsOffset", &x.tsOffset) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("sackPermitted", &x.sackPermitted) + m.Load("sack", &x.sack) + m.Load("reusePort", &x.reusePort) + m.Load("bindToDevice", &x.bindToDevice) + m.Load("delay", &x.delay) + m.Load("cork", &x.cork) + m.Load("scoreboard", &x.scoreboard) + m.Load("reuseAddr", &x.reuseAddr) + m.Load("slowAck", &x.slowAck) + m.LoadWait("segmentQueue", &x.segmentQueue) + m.Load("synRcvdCount", &x.synRcvdCount) + m.Load("userMSS", &x.userMSS) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("sndBufUsed", &x.sndBufUsed) + m.Load("sndClosed", &x.sndClosed) + m.Load("sndBufInQueue", &x.sndBufInQueue) + m.LoadWait("sndQueue", &x.sndQueue) + m.Load("cc", &x.cc) + m.Load("packetTooBigCount", &x.packetTooBigCount) + m.Load("sndMTU", &x.sndMTU) + m.Load("keepalive", &x.keepalive) + m.LoadWait("rcv", &x.rcv) + m.LoadWait("snd", &x.snd) + m.Load("connectingAddress", &x.connectingAddress) + m.Load("amss", &x.amss) + m.Load("sendTOS", &x.sendTOS) + m.Load("gso", &x.gso) + m.Load("tcpLingerTimeout", &x.tcpLingerTimeout) + m.Load("closed", &x.closed) + m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) }) + m.LoadValue("state", new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) + m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *keepalive) beforeSave() {} +func (x *keepalive) save(m state.Map) { + x.beforeSave() + m.Save("enabled", &x.enabled) + m.Save("idle", &x.idle) + m.Save("interval", &x.interval) + m.Save("count", &x.count) + m.Save("unacked", &x.unacked) +} + +func (x *keepalive) afterLoad() {} +func (x *keepalive) load(m state.Map) { + m.Load("enabled", &x.enabled) + m.Load("idle", &x.idle) + m.Load("interval", &x.interval) + m.Load("count", &x.count) + m.Load("unacked", &x.unacked) +} + +func (x *receiver) beforeSave() {} +func (x *receiver) save(m state.Map) { + x.beforeSave() + m.Save("ep", &x.ep) + m.Save("rcvNxt", &x.rcvNxt) + m.Save("rcvAcc", &x.rcvAcc) + m.Save("rcvWnd", &x.rcvWnd) + m.Save("rcvWndScale", &x.rcvWndScale) + m.Save("closed", &x.closed) + m.Save("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Save("pendingBufUsed", &x.pendingBufUsed) + m.Save("pendingBufSize", &x.pendingBufSize) +} + +func (x *receiver) afterLoad() {} +func (x *receiver) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("rcvNxt", &x.rcvNxt) + m.Load("rcvAcc", &x.rcvAcc) + m.Load("rcvWnd", &x.rcvWnd) + m.Load("rcvWndScale", &x.rcvWndScale) + m.Load("closed", &x.closed) + m.Load("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Load("pendingBufUsed", &x.pendingBufUsed) + m.Load("pendingBufSize", &x.pendingBufSize) +} + +func (x *renoState) beforeSave() {} +func (x *renoState) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *renoState) afterLoad() {} +func (x *renoState) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *SACKScoreboard) beforeSave() {} +func (x *SACKScoreboard) save(m state.Map) { + x.beforeSave() + m.Save("smss", &x.smss) + m.Save("maxSACKED", &x.maxSACKED) +} + +func (x *SACKScoreboard) afterLoad() {} +func (x *SACKScoreboard) load(m state.Map) { + m.Load("smss", &x.smss) + m.Load("maxSACKED", &x.maxSACKED) +} + +func (x *segment) beforeSave() {} +func (x *segment) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + var options []byte = x.saveOptions() + m.SaveValue("options", options) + var rcvdTime unixTime = x.saveRcvdTime() + m.SaveValue("rcvdTime", rcvdTime) + var xmitTime unixTime = x.saveXmitTime() + m.SaveValue("xmitTime", xmitTime) + m.Save("segmentEntry", &x.segmentEntry) + m.Save("refCnt", &x.refCnt) + m.Save("viewToDeliver", &x.viewToDeliver) + m.Save("sequenceNumber", &x.sequenceNumber) + m.Save("ackNumber", &x.ackNumber) + m.Save("flags", &x.flags) + m.Save("window", &x.window) + m.Save("csum", &x.csum) + m.Save("csumValid", &x.csumValid) + m.Save("parsedOptions", &x.parsedOptions) + m.Save("hasNewSACKInfo", &x.hasNewSACKInfo) +} + +func (x *segment) afterLoad() {} +func (x *segment) load(m state.Map) { + m.Load("segmentEntry", &x.segmentEntry) + m.Load("refCnt", &x.refCnt) + m.Load("viewToDeliver", &x.viewToDeliver) + m.Load("sequenceNumber", &x.sequenceNumber) + m.Load("ackNumber", &x.ackNumber) + m.Load("flags", &x.flags) + m.Load("window", &x.window) + m.Load("csum", &x.csum) + m.Load("csumValid", &x.csumValid) + m.Load("parsedOptions", &x.parsedOptions) + m.Load("hasNewSACKInfo", &x.hasNewSACKInfo) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) + m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) }) + m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) }) + m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) }) +} + +func (x *segmentQueue) beforeSave() {} +func (x *segmentQueue) save(m state.Map) { + x.beforeSave() + m.Save("list", &x.list) + m.Save("limit", &x.limit) + m.Save("used", &x.used) +} + +func (x *segmentQueue) afterLoad() {} +func (x *segmentQueue) load(m state.Map) { + m.LoadWait("list", &x.list) + m.Load("limit", &x.limit) + m.Load("used", &x.used) +} + +func (x *sender) beforeSave() {} +func (x *sender) save(m state.Map) { + x.beforeSave() + var lastSendTime unixTime = x.saveLastSendTime() + m.SaveValue("lastSendTime", lastSendTime) + var rttMeasureTime unixTime = x.saveRttMeasureTime() + m.SaveValue("rttMeasureTime", rttMeasureTime) + m.Save("ep", &x.ep) + m.Save("dupAckCount", &x.dupAckCount) + m.Save("fr", &x.fr) + m.Save("sndCwnd", &x.sndCwnd) + m.Save("sndSsthresh", &x.sndSsthresh) + m.Save("sndCAAckCount", &x.sndCAAckCount) + m.Save("outstanding", &x.outstanding) + m.Save("sndWnd", &x.sndWnd) + m.Save("sndUna", &x.sndUna) + m.Save("sndNxt", &x.sndNxt) + m.Save("sndNxtList", &x.sndNxtList) + m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Save("closed", &x.closed) + m.Save("writeNext", &x.writeNext) + m.Save("writeList", &x.writeList) + m.Save("rtt", &x.rtt) + m.Save("rto", &x.rto) + m.Save("maxPayloadSize", &x.maxPayloadSize) + m.Save("gso", &x.gso) + m.Save("sndWndScale", &x.sndWndScale) + m.Save("maxSentAck", &x.maxSentAck) + m.Save("state", &x.state) + m.Save("cc", &x.cc) +} + +func (x *sender) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("dupAckCount", &x.dupAckCount) + m.Load("fr", &x.fr) + m.Load("sndCwnd", &x.sndCwnd) + m.Load("sndSsthresh", &x.sndSsthresh) + m.Load("sndCAAckCount", &x.sndCAAckCount) + m.Load("outstanding", &x.outstanding) + m.Load("sndWnd", &x.sndWnd) + m.Load("sndUna", &x.sndUna) + m.Load("sndNxt", &x.sndNxt) + m.Load("sndNxtList", &x.sndNxtList) + m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Load("closed", &x.closed) + m.Load("writeNext", &x.writeNext) + m.Load("writeList", &x.writeList) + m.Load("rtt", &x.rtt) + m.Load("rto", &x.rto) + m.Load("maxPayloadSize", &x.maxPayloadSize) + m.Load("gso", &x.gso) + m.Load("sndWndScale", &x.sndWndScale) + m.Load("maxSentAck", &x.maxSentAck) + m.Load("state", &x.state) + m.Load("cc", &x.cc) + m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) }) + m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *rtt) beforeSave() {} +func (x *rtt) save(m state.Map) { + x.beforeSave() + m.Save("srtt", &x.srtt) + m.Save("rttvar", &x.rttvar) + m.Save("srttInited", &x.srttInited) +} + +func (x *rtt) afterLoad() {} +func (x *rtt) load(m state.Map) { + m.Load("srtt", &x.srtt) + m.Load("rttvar", &x.rttvar) + m.Load("srttInited", &x.srttInited) +} + +func (x *fastRecovery) beforeSave() {} +func (x *fastRecovery) save(m state.Map) { + x.beforeSave() + m.Save("active", &x.active) + m.Save("first", &x.first) + m.Save("last", &x.last) + m.Save("maxCwnd", &x.maxCwnd) + m.Save("highRxt", &x.highRxt) + m.Save("rescueRxt", &x.rescueRxt) +} + +func (x *fastRecovery) afterLoad() {} +func (x *fastRecovery) load(m state.Map) { + m.Load("active", &x.active) + m.Load("first", &x.first) + m.Load("last", &x.last) + m.Load("maxCwnd", &x.maxCwnd) + m.Load("highRxt", &x.highRxt) + m.Load("rescueRxt", &x.rescueRxt) +} + +func (x *unixTime) beforeSave() {} +func (x *unixTime) save(m state.Map) { + x.beforeSave() + m.Save("second", &x.second) + m.Save("nano", &x.nano) +} + +func (x *unixTime) afterLoad() {} +func (x *unixTime) load(m state.Map) { + m.Load("second", &x.second) + m.Load("nano", &x.nano) +} + +func (x *segmentList) beforeSave() {} +func (x *segmentList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *segmentList) afterLoad() {} +func (x *segmentList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *segmentEntry) beforeSave() {} +func (x *segmentEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *segmentEntry) afterLoad() {} +func (x *segmentEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("tcp.cubicState", (*cubicState)(nil), state.Fns{Save: (*cubicState).save, Load: (*cubicState).load}) + state.Register("tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load}) + state.Register("tcp.rcvBufAutoTuneParams", (*rcvBufAutoTuneParams)(nil), state.Fns{Save: (*rcvBufAutoTuneParams).save, Load: (*rcvBufAutoTuneParams).load}) + state.Register("tcp.EndpointInfo", (*EndpointInfo)(nil), state.Fns{Save: (*EndpointInfo).save, Load: (*EndpointInfo).load}) + state.Register("tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load}) + state.Register("tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load}) + state.Register("tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load}) + state.Register("tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load}) + state.Register("tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load}) + state.Register("tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load}) + state.Register("tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load}) + state.Register("tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load}) + state.Register("tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load}) + state.Register("tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load}) + state.Register("tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load}) + state.Register("tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go deleted file mode 100644 index bc5cfcf0e..000000000 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ /dev/null @@ -1,6342 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65535 - - // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an - // IPv4 endpoint when the MTU is set to defaultMTU in the test. - defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize -) - -func TestGiveUpConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - // Close the connection, wait for completion. - ep.Close() - - // Wait for ep to become writable. - <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) - } - - // Call Connect again to retreive the handshake failure status - // and stats updates. - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) - } - - if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func TestConnectIncrementActiveConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.FailedConnectionAttempts.Value() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) - } -} - -func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 - - if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) - } - - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) - } -} - -func TestTCPSegmentsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - // SYN and ACK - want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) - } -} - -func TestTCPResetsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - want := stats.TCP.SegmentsSent.Value() + 1 - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - // If the AckNum is not the increment of the last sequence number, a RST - // segment is sent back in response. - AckNum: c.IRS + 2, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.GetPacket() - if got := stats.TCP.ResetsSent.Value(); got != want { - t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) - } -} - -// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates -// a RST if an ACK is received on the listening socket for which there is no -// active handshake in progress and we are not using SYN cookies. -func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Lower stackwide TIME_WAIT timeout so that the reservations - // are released instantly on Close. - tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - c.GetPacket() - - // Since an active close was done we need to wait for a little more than - // tcpLingerTimeout for the port reservations to be released and the - // socket to move to a CLOSED state. - time.Sleep(20 * time.Millisecond) - - // Now resend the same ACK, this ACK should generate a RST as there - // should be no endpoint in SYN-RCVD state and we are not using - // syn-cookies yet. The reason we send the same ACK is we need a valid - // cookie(IRS) generated by the netstack without which the ACK will be - // rejected. - c.SendPacket(nil, ackHeaders) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) -} - -func TestTCPResetsReceivedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) - } -} - -func TestTCPResetsDoNotGenerateResets(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) - } - c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) -} - -func TestActiveHandshake(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) -} - -func TestNonBlockingClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) - } -} - -func TestConnectResetAfterClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLinger to 3 seconds so that sockets are marked closed - // after 3 second in FIN_WAIT2 state. - tcpLingerTimeout := 3 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint, make sure we get a FIN segment, then acknowledge - // to complete closure of sender, but don't send our own FIN. - ep.Close() - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Wait for the ep to give up waiting for a FIN. - time.Sleep(tcpLingerTimeout + 1*time.Second) - - // Now send an ACK and it should trigger a RST as the endpoint should - // not exist anymore. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - for { - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { - // This is a retransmit of the FIN, ignore it. - continue - } - - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - ), - ) - break - } -} - -// TestClosingWithEnqueuedSegments tests handling of -// still enqueued segments when the endpoint transitions -// to StateClose. The in-flight segments would be re-enqueued -// to a any listening endpoint. -func TestClosingWithEnqueuedSegments(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Send a FIN for ESTABLISHED --> CLOSED-WAIT - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Get the ACK for the FIN we sent. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Close the application endpoint for CLOSE_WAIT --> LAST_ACK - ep.Close() - - // Get the FIN - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Pause the endpoint`s protocolMainLoop. - ep.(interface{ StopWork() }).StopWork() - - // Enqueue last ACK followed by an ACK matching the endpoint - // - // Send Last ACK for LAST_ACK --> CLOSED - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 791, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Send a packet with ACK set, this would generate RST when - // not using SYN cookies as in this test. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 792, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Unpause endpoint`s protocolMainLoop. - ep.(interface{ ResumeWork() }).ResumeWork() - - // Wait for the protocolMainLoop to resume and update state. - time.Sleep(10 * time.Millisecond) - - // Expect the endpoint to be closed. - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } - - // Check if the endpoint was moved to CLOSED and netstack a reset in - // response to the ACK packet that we sent after last-ACK. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - ), - ) -} - -func TestSimpleReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when -// creating a new active IPv4 TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnectV4(t *testing.T) { - const mtu = 5000 - const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, - }, - { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - c.Create(-1) - - // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) - } - - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - - // Start connection attempt to IPv4 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet with our user supplied MSS. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) - }) - } -} - -// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when -// creating a new active IPv6 TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnectV6(t *testing.T) { - const mtu = 5000 - const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, - }, - { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) - } - - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - - // Start connection attempt to IPv6 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet with our user supplied MSS. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) - }) - } -} - -func TestSendRstOnListenerRxSynAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) -} - -func TestSendRstOnListenerRxSynAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) -} - -// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv4. -func TestTCPAckBeforeAcceptV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) -} - -// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv6. -func TestTCPAckBeforeAcceptV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) -} - -func TestSendRstOnListenerRxAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(200))) -} - -func TestSendRstOnListenerRxAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true /* v6Only */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(200))) -} - -// TestListenShutdown tests for the listening endpoint not processing -// any receive when it is on read shutdown. -func TestListenShutdown(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatal("Shutdown failed:", err) - } - - // Wait for the endpoint state to be propagated. - time.Sleep(10 * time.Millisecond) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: 100, - AckNum: 200, - }) - - c.CheckNoPacket("Packet received when listening socket was shutdown") -} - -func TestTOSV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - - const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) - } - - var v tcpip.IPv4TOSOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv4TOSOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - testV4Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestTrafficClassV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) - } - - var v tcpip.IPv6TrafficClassOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - // Test the connection request. - testV6Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetV6Packet() - checker.IPv6(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestConnectBindToDevice(t *testing.T) { - for _, test := range []struct { - name string - device string - want tcp.EndpointState - }{ - {"RightDevice", "nic1", tcp.StateEstablished}, - {"WrongDevice", "nic2", tcp.StateSynSent}, - {"AnyDevice", "", tcp.StateEstablished}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - bindToDevice := tcpip.BindToDeviceOption(test.device) - c.EP.SetSockOpt(bindToDevice) - // Start connection attempt. - waitEntry, _ := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - - c.GetPacket() - if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - }) - } -} - -func TestOutOfOrderReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send second half of data first, with seqnum 3 ahead of expected. - data := []byte{1, 2, 3, 4, 5, 6} - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that we get an ACK specifying which seqnum is expected. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait 200ms and check that no data has been received. - time.Sleep(200 * time.Millisecond) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send the first 3 bytes now. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive data. - read := make([]byte, 0, 6) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } - t.Fatalf("Read failed: %v", err) - } - - read = append(read, v...) - } - - // Check that we received the data in proper order. - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that the whole data is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestOutOfOrderFlood(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create a new connection with initial window size of 10. - c.CreateConnected(789, 30000, 10) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send 100 packets before the actual one that is expected. - data := []byte{1, 2, 3, 4, 5, 6} - for i := 0; i < 100; i++ { - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 796, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Send packet with seqnum 793. It must be discarded because the - // out-of-order buffer was filled by the previous packets. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now send the expected packet, seqnum 790. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that only packet 790 is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestRstOnCloseWithUnreadData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now that we know we have unread data, let's just close the connection - // and verify that netstack sends an RST rather than a FIN. - c.EP.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // This final ACK should be ignored because an ACK on a reset doesn't mean - // anything. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) - - // Make sure we get the FIN but DON't ACK IT. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), - )) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Cause a RST to be generated by closing the read end now since we have - // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) - - // Make sure we get the RST - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - checker.SeqNum(uint32(c.IRS)+2), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // The ACK to the FIN should now be rejected since the connection has been - // closed by a RST. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestShutdownRead(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) - } - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) - } -} - -func TestFullWindowReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, 10) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - _, _, err := c.EP.Read(nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) - } - - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and window goes to zero. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), - ), - ) - - // Receive data and check it. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) - } - - // Check that we get an ACK for the newly non-zero window. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), - ), - ) -} - -func TestNoWindowShrinking(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Start off with a window size of 10, then shrink it to 5. - c.CreateConnected(789, 30000, 10) - - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), - ), - ) - - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), - ), - ) - - // Receive data and check it. - read := make([]byte, 0, 10) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - read = append(read, v...) - } - - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that we get an ACK for the newly non-zero window, which is the - // new size. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), - ), - ) -} - -func TestSimpleSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestZeroWindowSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 0, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Since the window is currently zero, check that no packet is received. - c.CheckNoPacket("Packet received when window is zero") - - // Open up the window. Data should be received now. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xbfff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(789, 30000, 65535*3) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xbfff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN - // should not carry the window scaling option. - c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestZeroScaledWindowReceive(t *testing.T) { - // This test ensures that the endpoint sends a non-zero window size - // advertisement when the scaled window transitions from 0 to non-zero, - // but the actual window (not scaled) hasn't gotten to zero. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - // Write chunks of 50000 bytes. - remain := wnd - sent := 0 - data := make([]byte, 50000) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Make the window non-zero, but the scaled window zero. - if remain >= 16 { - data = data[:remain-15] - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Read some data. An ack should be sent in response to that. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(len(v)>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestSegmentMerging(t *testing.T) { - tests := []struct { - name string - stop func(tcpip.Endpoint) - resume func(tcpip.Endpoint) - }{ - { - "stop work", - func(ep tcpip.Endpoint) { - ep.(interface{ StopWork() }).StopWork() - }, - func(ep tcpip.Endpoint) { - ep.(interface{ ResumeWork() }).ResumeWork() - }, - }, - { - "cork", - func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) - }, - func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Prevent the endpoint from processing packets. - test.stop(c.EP) - - var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - // Let the endpoint process the segments that we just sent. - test.resume(c.EP) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), - RcvWnd: 30000, - }) - }) - } -} - -func TestDelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - c.EP.SetSockOptInt(tcpip.DelayOption, 1) - - var allData []byte - for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - for _, want := range [][]byte{allData[:1], allData[1:]} { - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(want)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { - t.Fatalf("got data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(want))) - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - } -} - -func TestUndelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - c.EP.SetSockOptInt(tcpip.DelayOption, 1) - - allData := [][]byte{{0}, {1, 2, 3}} - for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - - // Check that data is received. - first := c.GetPacket() - checker.IPv4(t, first, - checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { - t.Fatalf("got first packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[0]))) - - // Check that we don't get the second packet yet. - c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - - c.EP.SetSockOptInt(tcpip.DelayOption, 0) - - // Check that data is received. - second := c.GetPacket() - checker.IPv4(t, second, - checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { - t.Fatalf("got second packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[1]))) - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) -} - -func TestMSSNotDelayed(t *testing.T) { - tests := []struct { - name string - fn func(tcpip.Endpoint) - }{ - {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const maxPayload = 100 - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - test.fn(c.EP) - - allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} - for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - - for i, data := range allData { - // Check that data is received. - packet := c.GetPacket() - checker.IPv4(t, packet, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { - t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) - } - - seq = seq.Add(seqnum.Size(len(data))) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - }) - } -} - -func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { - payloadMultiplier := 10 - dataLen := payloadMultiplier * maxPayload - data := make([]byte, dataLen) - for i := range data { - data[i] = byte(i) - } - - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received in chunks. - bytesReceived := 0 - numPackets := 0 - for bytesReceived != dataLen { - b := c.GetPacket() - numPackets++ - tcpHdr := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcpHdr.Payload()) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { - t.Fatalf("got data = %v, want = %v", p, pdata) - } - bytesReceived += payloadLen - var options []byte - if c.TimeStampEnabled { - // If timestamp option is enabled, echo back the timestamp and increment - // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcpHdr.ParsedOptions() - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) - options = tsOpt[:] - } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options, - }) - } - if numPackets == 1 { - t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") - } -} - -func TestSendGreaterThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSetTTL(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := context.New(t, 65535) - defer c.Cleanup() - - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) - } - - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - - checker.IPv4(t, b, checker.TTL(wantTTL)) - }) - } -} - -func TestActiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, 65535) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestPassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 536 - const mtu = 2000 - c := context.New(t, mtu) - defer c.Cleanup() - - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestForwarderSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan *tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err *tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Wait for connection to be available. - select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynOptionsOnActiveConnect(t *testing.T) { - const mtu = 1400 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 2 - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - // Receive SYN packet. - b := c.GetPacket() - mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Wait for retransmit. - time.Sleep(1 * time.Second) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - // Send SYN-ACK. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestCloseListener(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Close the listener and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) - } -} - -func TestReceiveOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Try to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - -loop: - for { - switch _, _, err := c.EP.Read(nil); err { - case tcpip.ErrWouldBlock: - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") - } - case tcpip.ErrConnectionReset: - break loop - default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) - } - } - // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) - } - if tcp.EndpointState(c.EP.State()) != tcp.StateError { - t.Fatalf("got EP state is not StateError") - } - - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func TestSendOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Wait for the RST to be received. - time.Sleep(1 * time.Second) - - // Try to write. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset) - } -} - -func TestFinImmediately(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Don't acknowledge yet. We should get a retransmit of the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithNoPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Shutdown, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingDataCwndFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write enough segments to fill the congestion window before ACK'ing - // any of them. - view := buffer.NewView(10) - for i := tcp.InitialCwnd; i > 0; i-- { - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - } - - next := uint32(c.IRS) + 1 - for i := tcp.InitialCwnd; i > 0; i-- { - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - } - - // Shutdown the connection, check that the FIN segment isn't sent - // because the congestion window doesn't allow it. Wait until a - // retransmit is received. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Send the ACK that will allow the FIN to be sent as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPartialAck(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. Also send - // FIN from the test side. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that we get an ACK for the fin. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send an ACK for the data, but not for the FIN yet. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 791, - AckNum: seqnum.Value(next - 1), - RcvWnd: 30000, - }) - - // Check that we don't get a retransmit of the FIN. - c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) - - // Ack the FIN. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 791, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) -} - -func TestUpdateListenBacklog(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Update the backlog with another Listen() on the same endpoint. - if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %v", err) - } - - ep.Close() -} - -func scaledSendWindow(t *testing.T, scale uint8) { - // This test ensures that the endpoint is using the right scaling by - // sending a buffer that is larger than the window size, and ensuring - // that the endpoint doesn't send more than allowed. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - header.TCPOptionWS, 3, scale, header.TCPOptionNOP, - }) - - // Open up the window with a scaled value. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 1, - }) - - // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that only data that fits in the scaled window is sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen((1<<scale)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Reset the connection to free resources. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - }) -} - -func TestScaledSendWindow(t *testing.T) { - for scale := uint8(0); scale <= 14; scale++ { - scaledSendWindow(t, scale) - } -} - -func TestReceivedValidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ValidSegmentsReceived.Value() + 1 - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) - } - // Ensure there were no errors during handshake. If these stats have - // incremented, then the connection should not have been established. - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) - } - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) - } -} - -func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.InvalidSegmentsReceived.Value() + 1 - vv := c.BuildSegment(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] - tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 - - c.SendSegment(vv) - - if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) - } -} - -func TestReceivedIncorrectChecksumIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ChecksumErrors.Value() + 1 - vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] - // Overwrite a byte in the payload which should cause checksum - // verification to fail. - tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 - - c.SendSegment(vv) - - if got := stats.TCP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) - } -} - -func TestReceivedSegmentQueuing(t *testing.T) { - // This test sends 200 segments containing a few bytes each to an - // endpoint and checks that they're all received and acknowledged by - // the endpoint, that is, that none of the segments are dropped by - // internal queues. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - // Send 200 segments. - data := []byte{1, 2, 3} - for i := 0; i < 200; i++ { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + i*len(data)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - } - - // Receive ACKs for all segments. - last := seqnum.Value(790 + 200*len(data)) - for { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - tcpHdr := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcpHdr.AckNumber()) - if ack == last { - break - } - - if last.LessThan(ack) { - t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) - } - } -} - -func TestReadAfterClosedState(t *testing.T) { - // This test ensures that calling Read() or Peek() after the endpoint - // has transitioned to closedState still works if there is pending - // data. To transition to stateClosed without calling Close(), we must - // shutdown the send path and the peer must send its own FIN. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) - } - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) - } - - // Shutdown immediately for write, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Send some data and acknowledge the FIN. - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(uint32(791+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack the chance to transition to closed state from - // TIME_WAIT. - time.Sleep(tcpTimeWaitTimeout * 2) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that peek works. - peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) - if err != nil { - t.Fatalf("Peek failed: %s", err) - } - - peekBuf = peekBuf[:n] - if !bytes.Equal(data, peekBuf) { - t.Fatalf("got data = %v, want = %v", peekBuf, data) - } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Now that we drained the queue, check that functions fail with the - // right error code. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) - } - - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) - } -} - -func TestReusePort(t *testing.T) { - // This test ensures that ports are immediately available for reuse - // after Close on the endpoints using them returns. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // First case, just an endpoint that was bound. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - c.EP.Close() - - // Second case, an endpoint that was bound and is connecting.. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - c.EP.Close() - - // Third case, an endpoint that was bound and is listening. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } -} - -func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - if int(s) != v { - t.Fatalf("got receive buffer size = %v, want = %v", s, v) - } -} - -func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - if int(s) != v { - t.Fatalf("got send buffer size = %v, want = %v", s, v) - } -} - -func TestDefaultBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer func() { - if ep != nil { - ep.Close() - } - }() - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) -} - -func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Set values below the min. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkRecvBufferSize(t, ep, 200) - - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkSendBufferSize(t, ep, 300) - - // Set values above the max. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) - - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // strPtr is used instead of taking the address of string literals, which is - // a compiler error. - strPtr := func(s string) *string { - return &s - } - - testActions := []struct { - name string - setBindToDevice *string - setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption - }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) - } - } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) - } - }) - } -} - -func makeStack() (*stack.Stack, *tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), - }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - id := loopback.New() - if testing.Verbose() { - id = sniffer.New(id) - } - - if err := s.CreateNIC(1, id); err != nil { - return nil, err - } - - for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address - }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, - } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { - return nil, err - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return s, nil -} - -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Write something. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Read back what was written. - wq.EventUnregister(&waitEntry) - wq.EventRegister(&waitEntry, waiter.EventIn) - rd, _, err := ep.Read(nil) - if err != nil { - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) - } - <-notifyCh - rd, _, err = ep.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - } - - if !bytes.Equal(data, rd) { - t.Fatalf("got data = %v, want = %v", rd, data) - } -} - -func TestConnectAvoidsBoundPorts(t *testing.T) { - addressTypes := func(t *testing.T, network string) []string { - switch network { - case "ipv4": - return []string{"v4"} - case "ipv6": - return []string{"v6"} - case "dual": - return []string{"v6", "mapped"} - default: - t.Fatalf("unknown network: '%s'", network) - } - - panic("unreachable") - } - - address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { - switch addressType { - case "v4": - if isAny { - return "" - } - return context.StackAddr - case "v6": - if isAny { - return "" - } - return context.StackV6Addr - case "mapped": - if isAny { - return context.V4MappedWildcardAddr - } - return context.StackV4MappedAddr - default: - t.Fatalf("unknown address type: '%s'", addressType) - } - - panic("unreachable") - } - // This test ensures that Endpoint.Connect doesn't select already-bound ports. - networks := []string{"ipv4", "ipv6", "dual"} - for _, exhaustedNetwork := range networks { - t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { - for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { - t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { - for _, isAny := range []bool{false, true} { - t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { - for _, candidateNetwork := range networks { - t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { - for _, candidateAddressType := range addressTypes(t, candidateNetwork) { - t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - var eps []tcpip.Endpoint - defer func() { - for _, ep := range eps { - ep.Close() - } - }() - makeEP := func(network string) tcpip.Endpoint { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch network { - case "ipv4": - networkProtocolNumber = ipv4.ProtocolNumber - case "ipv6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatalf("unknown network: '%s'", network) - } - ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - eps = append(eps, ep) - switch network { - case "ipv4": - case "ipv6": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) - } - case "dual": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) - } - default: - t.Fatalf("unknown network: '%s'", network) - } - return ep - } - - var v4reserved, v6reserved bool - switch exhaustedAddressType { - case "v4", "mapped": - v4reserved = true - case "v6": - v6reserved = true - // Dual stack sockets bound to v6 any reserve on v4 as - // well. - if isAny { - switch exhaustedNetwork { - case "ipv6": - case "dual": - v4reserved = true - default: - t.Fatalf("unknown address type: '%s'", exhaustedNetwork) - } - } - default: - t.Fatalf("unknown address type: '%s'", exhaustedAddressType) - } - var collides bool - switch candidateAddressType { - case "v4", "mapped": - collides = v4reserved - case "v6": - collides = v6reserved - default: - t.Fatalf("unknown address type: '%s'", candidateAddressType) - } - - for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { - if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %v", i, err) - } - } - want := tcpip.ErrConnectStarted - if collides { - want = tcpip.ErrNoPortAvailable - } - if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want) - } - }) - } - }) - } - }) - } - }) - } - }) - } -} - -func TestPathMTUDiscovery(t *testing.T) { - // This test verifies the stack retransmits packets after it receives an - // ICMP packet indicating that the path MTU has been exceeded. - c := context.New(t, 1500) - defer c.Cleanup() - - // Create new connection with MSS of 1460. - const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - // Send 3200 bytes of data. - const writeSize = 3200 - data := buffer.NewView(writeSize) - for i := range data { - data[i] = byte(i) - } - - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { - var ret []byte - for i, size := range sizes { - p := c.GetPacket() - if i == which { - ret = p - } - checker.IPv4(t, p, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(seqNum), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - seqNum += uint32(size) - } - return ret - } - - // Receive three packets. - sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} - first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) - - // Send "packet too big" messages back to netstack. - const newMTU = 1200 - const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize - mtu := []byte{0, 0, newMTU / 256, newMTU % 256} - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) - - // See retransmitted packets. None exceeding the new max. - sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} - receivePackets(c, sizes, -1, uint32(c.IRS)+1) -} - -func TestTCPEndpointProbe(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - invoked := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint ID is what we expect. - // - // We don't do an extensive validation of every field but a - // basic sanity test. - if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("got LocalAddress: %q, want: %q", got, want) - } - if got, want := state.ID.LocalPort, c.Port; got != want { - t.Fatalf("got LocalPort: %d, want: %d", got, want) - } - if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("got RemoteAddress: %q, want: %q", got, want) - } - if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("got RemotePort: %d, want: %d", got, want) - } - - invoked <- struct{}{} - }) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-invoked: - case <-time.After(100 * time.Millisecond): - t.Fatalf("TCP Probe function was not called") - } -} - -func TestStackSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err *tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - var oldCC tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - - got, want := cc, oldCC - // If SetTransportProtocolOption is expected to succeed - // then the returned value for congestion control should - // match the one specified in the - // SetTransportProtocolOption call above, else it should - // be what it was before the call to - // SetTransportProtocolOption. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } -} - -func TestStackAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) - } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestStackSetAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) - } - - // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestEndpointSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err *tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, - } - - for _, connected := range []bool{false, true} { - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - var oldCC tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) - } - - if connected { - c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) - } - - if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) - } - - got, want := cc, oldCC - // If SetSockOpt is expected to succeed then the - // returned value for congestion control should match - // the one specified in the SetSockOpt above, else it - // should be what it was before the call to SetSockOpt. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } - } -} - -func enableCUBIC(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) - } -} - -func TestKeepalive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) - - // 5 unacked keepalives are sent. ACK each one, and check that the - // connection stays alive after 5. - for i := 0; i < 10; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Acknowledge the keepalive. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS, - RcvWnd: 30000, - }) - } - - // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send some data and wait before ACKing it. Keepalives should be disabled - // during this period. - view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Wait for the packet to be retransmitted. Verify that no keepalives - // were sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - - next += uint32(len(view)) - - // Send ACK. Keepalives should start sending again. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Now receive 5 keepalives, but don't ACK them. The connection - // should be reset after 5. - for i := 0; i < 5; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next-1)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // The connection should be terminated after 5 unacked keepalives. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - ), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) - } - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - // Send a SYN request. - irs = seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptions(), - })) - } - - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - // Send a SYN request. - irs = seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptionsV6(), - })) - } - - checker.IPv6(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -// TestListenBacklogFull tests that netstack does not complete handshakes if the -// listen backlog for the endpoint is full. -func TestListenBacklogFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 2 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) - } - - time.Sleep(50 * time.Millisecond) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() - if err != tcpip.ErrWouldBlock { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } - - // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - - newEP, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a -// non unicast IPv4 address are not accepted. -func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpip.Address("\xe0\x00\x01\x02") - otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, - }, - { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, - }, - { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, - }, - { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, - }, - { - "DestOurMulticast", - context.TestAddr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(789) - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestAddr, context.StackAddr) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) - }) - } -} - -// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a -// non unicast IPv6 address are not accepted. -func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv6Any, - context.StackV6Addr, - }, - { - "SourceAllNodes", - header.IPv6AllNodesMulticastAddress, - context.StackV6Addr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackV6Addr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackV6Addr, - }, - { - "DestUnspecified", - context.TestV6Addr, - header.IPv6Any, - }, - { - "DestAllNodes", - context.TestV6Addr, - header.IPv6AllNodesMulticastAddress, - }, - { - "DestOurMulticast", - context.TestV6Addr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestV6Addr, - otherMulticastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(789) - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestV6Addr, context.StackV6Addr) - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) - }) - } -} - -func TestListenSynRcvdQueueFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send two SYN's the first one should get a SYN-ACK, the - // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - newEP, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestListenBacklogFullSynCookieInUse(t *testing.T) { - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 1 - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 1 - portOffset := uint16(0) - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - executeHandshake(t, c, context.TestPort+portOffset, false) - portOffset++ - // Wait for this to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - // pick a different src port for new SYN. - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - // The Syn should be dropped as the endpoint's backlog is full. - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Verify that there is only one acceptable connection at this point. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() - if err != tcpip.ErrWouldBlock { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } -} - -func TestSynRcvdBadSeqNumber(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcpHdr.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now send a packet with an out-of-window sequence number - largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: largeSeqnum, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Should receive an ACK with the expected SEQ number - b = c.GetPacket() - tcpCheckers = []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.AckNum(uint32(irs) + 1), - checker.SeqNum(uint32(iss + 1)), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now that the socket replied appropriately with the ACK, - // complete the connection to test that the large SEQ num - // did not change the state from SYN-RCVD. - - // Send ACK to move to ESTABLISHED state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - newEP, _, err := c.EP.Accept() - - if err != nil && err != tcpip.ErrWouldBlock { - t.Fatalf("Accept failed: %s", err) - } - - if err == tcpip.ErrWouldBlock { - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - - if err != nil { - t.Fatalf("Write failed: %s", err) - } - - pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) - if string(tcpHdr.Payload()) != data { - t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) - } -} - -func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - - srcPort := uint16(context.TestPort) - executeHandshake(t, c, srcPort+1, false) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - srcPort := uint16(context.TestPort) - // Now attempt a handshakes it will fill up the accept backlog. - executeHandshake(t, c, srcPort, false) - - // Give time for the final ACK to be processed as otherwise the next handshake could - // get accepted before the previous one based on goroutine scheduling. - time.Sleep(50 * time.Millisecond) - - want := stats.TCP.ListenOverflowSynDrop.Value() + 1 - - // Now we will send one more SYN and this one should get dropped - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestEndpointBindListenAcceptState(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Expect InvalidEndpointState errors on a read at this point. - if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) - } - if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { - t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - aep, _, err := ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - aep, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { - t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) - } - // Listening endpoint remains in listen state. - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - ep.Close() - // Give worker goroutines time to receive the close notification. - time.Sleep(1 * time.Second) - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - // Accepted endpoint remains open when the listen endpoint is closed. - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - -} - -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. -func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - // Change the expected window scale to match the value needed for the - // maximum buffer size defined above. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - // NOTE: The timestamp values in the sent packets are meaningless to the - // peer so we just increment the timestamp value by 1 every batch as we - // are not really using them for anything. Send a single byte to verify - // the advertised window. - tsVal := rawEP.TSVal + 1 - - // Introduce a 25ms latency by delaying the first byte. - latency := 25 * time.Millisecond - time.Sleep(latency) - rawEP.SendPacketWithTS([]byte{1}, tsVal) - - // Verify that the ACK has the expected window. - wantRcvWnd := receiveBufferSize - wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale)) - rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1)) - time.Sleep(25 * time.Millisecond) - - // Allocate a large enough payload for the test. - b := make([]byte, int(receiveBufferSize)*2) - offset := 0 - payloadSize := receiveBufferSize - 1 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := offset - end := offset + payloadSize - packetsSent := 0 - for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - packetsSent++ - } - // Resume the worker so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Since we read no bytes the window should goto zero till the - // application reads some of the data. - // Discard all intermediate acks except the last one. - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } - } - rawEP.VerifyACKRcvWnd(0) - - time.Sleep(25 * time.Millisecond) - // Verify that sending more data when window is closed is dropped and - // not acked. - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - - // Verify that the stack sends us back an ACK with the sequence number - // of the last packet sent indicating it was dropped. - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - checker.Window(0), - )) - - // Now read all the data from the endpoint and verify that advertised - // window increases to the full available buffer size. - for { - _, _, err := c.EP.Read(nil) - if err == tcpip.ErrWouldBlock { - break - } - } - - // Verify that we receive a non-zero window update ACK. When running - // under thread santizer this test can end up sending more than 1 - // ack, 1 for the non-zero window - p = c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd) - } - }, - )) -} - -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. -func TestReceiveBufferAutoTuning(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - // Enable Auto-tuning. - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - // Change the expected window scale to match the value needed for the - // maximum buffer size used by stack. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - wantRcvWnd := receiveBufferSize - scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) - } - // Allocate a large array to send to the endpoint. - b := make([]byte, receiveBufferSize*48) - - // In every iteration we will send double the number of bytes sent in - // the previous iteration and read the same from the app. The received - // window should grow by at least 2x of bytes read by the app in every - // RTT. - offset := 0 - payloadSize := receiveBufferSize / 8 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal := rawEP.TSVal - // We are going to do our own computation of what the moderated receive - // buffer should be based on sent/copied data per RTT and verify that - // the advertised window by the stack matches our calculations. - prevCopied := 0 - done := false - latency := 1 * time.Millisecond - for i := 0; !done; i++ { - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := offset - end := offset + payloadSize - totalSent := 0 - packetsSent := 0 - for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - totalSent += mss - packetsSent++ - } - // Resume it so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Give 1ms for the worker to process the packets. - time.Sleep(1 * time.Millisecond) - - // Verify that the advertised window on the ACK is reduced by - // the total bytes sent. - expectedWnd := wantRcvWnd - totalSent - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } - } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) - - // Now read all the data from the endpoint and invoke the - // moderation API to allow for receive buffer auto-tuning - // to happen before we measure the new window. - totalCopied := 0 - for { - b, _, err := c.EP.Read(nil) - if err == tcpip.ErrWouldBlock { - break - } - totalCopied += len(b) - } - - // Invoke the moderation API. This is required for auto-tuning - // to happen. This method is normally expected to be invoked - // from a higher layer than tcpip.Endpoint. So we simulate - // copying to user-space by invoking it explicitly here. - c.EP.ModerateRecvBuf(totalCopied) - - // Now send a keep-alive packet to trigger an ACK so that we can - // measure the new window. - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - - if i == 0 { - // In the first iteration the receiver based RTT is not - // yet known as a result the moderation code should not - // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) - prevCopied = totalCopied - } else { - rttCopied := totalCopied - if i == 1 { - // The moderation code accumulates copied bytes till - // RTT is established. So add in the bytes sent in - // the first iteration to the total bytes for this - // RTT. - rttCopied += prevCopied - // Now reset it to the initial value used by the - // auto tuning logic. - prevCopied = tcp.InitialCwnd * mss * 2 - } - newWnd := rttCopied<<1 + 16*mss - grow := (newWnd * (rttCopied - prevCopied)) / prevCopied - newWnd += (grow << 1) - if newWnd > maxReceiveBufferSize { - newWnd = maxReceiveBufferSize - done = true - } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) - wantRcvWnd = newWnd - prevCopied = rttCopied - // Increase the latency after first two iterations to - // establish a low RTT value in the receiver since it - // only tracks the lowest value. This ensures that when - // ModerateRcvBuf is called the elapsed time is always > - // rtt. Without this the test is flaky due to delays due - // to scheduling/wakeup etc. - latency += 50 * time.Millisecond - } - time.Sleep(latency) - offset += payloadSize - payloadSize *= 2 - } -} - -func TestDelayEnabled(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - checkDelayOption(t, c, false, 0) // Delay is disabled by default. - - for _, v := range []struct { - delayEnabled tcp.DelayEnabled - wantDelayOption int - }{ - {delayEnabled: false, wantDelayOption: 0}, - {delayEnabled: true, wantDelayOption: 1}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) - } -} - -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { - t.Helper() - - var gotDelayEnabled tcp.DelayEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { - t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) - } - if gotDelayEnabled != wantDelayEnabled { - t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) - } - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) - if err != nil { - t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) - } - gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) - if err != nil { - t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) - } - if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) - } -} - -func TestTCPLingerTimeout(t *testing.T) { - c := context.New(t, 1500 /* mtu */) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - testCases := []struct { - name string - tcpLingerTimeout time.Duration - want time.Duration - }{ - {"NegativeLingerTimeout", -123123, 0}, - {"ZeroLingerTimeout", 0, 0}, - {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, - // Values > stack's TCPLingerTimeout are capped to the stack's - // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { - t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) - } - var v tcpip.TCPLingerTimeoutOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) - } - if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) - } - }) - } -} - -func TestTCPTimeWaitRSTIgnored(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now send a RST and this should be ignored and not - // generate an ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - }) - - c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitOutOfOrder(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitNewSyn(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Send a SYN request w/ sequence number lower than - // the highest sequence number sent. We just reuse - // the same number. - iss = seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) - - // Send a SYN request w/ sequence number higher than - // the highest sequence number sent. - iss = seqnum.Value(792) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) - } - - want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - time.Sleep(2 * time.Second) - - // Now send a duplicate FIN. This should cause the TIME_WAIT to extend - // by another 5 seconds and also send us a duplicate ACK as it should - // indicate that the final ACK was potentially lost. - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Sleep for 4 seconds so at this point we are 1 second past the - // original tcpLingerTimeout of 5 seconds. - time.Sleep(4 * time.Second) - - // Send an ACK and it should not generate any packet as the socket - // should still be in TIME_WAIT for another another 5 seconds due - // to the duplicate FIN we sent earlier. - *ackHeaders = *finHeaders - ackHeaders.SeqNum = ackHeaders.SeqNum + 1 - ackHeaders.Flags = header.TCPFlagAck - c.SendPacket(nil, ackHeaders) - - c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) - // Now sleep for another 2 seconds so that we are past the - // extended TIME_WAIT of 7 seconds (2 + 5). - time.Sleep(2 * time.Second) - - // Resend the same ACK. - c.SendPacket(nil, ackHeaders) - - // Receive the RST that should be generated as there is no valid - // endpoint. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) - } -} - -func TestTCPCloseWithData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) - } - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: 30000, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now trigger a passive close by sending a FIN. - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - RcvWnd: 30000, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now write a few bytes and then close the endpoint. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } - - c.EP.Close() - // Check the FIN. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.AckNum(uint32(iss+2)), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - // First send a partial ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now send a full ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now ACK the FIN. - ackHeaders.AckNum++ - c.SendPacket(nil, ackHeaders) - - // Now send an ACK and we should get a RST back as the endpoint should - // be in CLOSED state. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Check the RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) - -} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go deleted file mode 100644 index a641e953d..000000000 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -// createConnectedWithTimestampOption creates and connects c.ep with the -// timestamp option enabled. -func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) -} - -// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on -// an active connect and sets the TS Echo Reply fields correctly when the -// SYN-ACK also indicates support for the TS option and provides a TSVal. -func TestTimeStampEnabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read and validate that we have data to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // The following tests ensure that TS option once enabled behaves - // correctly as described in - // https://tools.ietf.org/html/rfc7323#section-4.3. - // - // We are not testing delayed ACKs here, but we do test out of order - // packet delivery and filling the sequence number hole created due to - // the out of order packet. - // - // The test also verifies that the sequence numbers and timestamps are - // as expected. - data := []byte{1, 2, 3} - - // First we increment tsVal by a small amount. - tsVal := rep.TSVal + 100 - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Next we send an out of order packet. - rep.NextSeqNum += 3 - tsVal += 200 - rep.SendPacketWithTS(data, tsVal) - - // The ACK should contain the original sequenceNumber and an older TS. - rep.NextSeqNum -= 6 - rep.VerifyACKWithTS(tsVal - 200) - - // Next we fill the hole and the returned ACK should contain the - // cumulative sequence number acking all data sent till now and have the - // latest timestamp sent below in its TSEcr field. - tsVal -= 100 - rep.SendPacketWithTS(data, tsVal) - rep.NextSeqNum += 3 - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal by a large value that doesn't result in a wrap around. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal again by a large value which should cause the - // timestamp value to wrap around. The returned ACK should contain the - // wrapped around timestamp in its tsEcr field and not the tsVal from - // the previous packet sent above. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // There should be 5 views to read and each of them should - // contain the same data. - for i := 0; i < 5; i++ { - got, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if want := data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } - } -} - -// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an -// active connect but if the SYN-ACK doesn't specify the TS option then -// timestamp option is not enabled and future packets do not contain a -// timestamp. -func TestTimeStampDisabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithOptions(header.TCPSynOptions{}) -} - -func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - c := context.New(t, defaultMTU) - defer c.Cleanup() - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) - - // Now send some data and validate that timestamp is echoed correctly in the ACK. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) - } - - // Check that data is received and that the timestamp option TSEcr field - // matches the expected value. - b := c.GetPacket() - checker.IPv4(t, b, - // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 - // byte boundary. - checker.PayloadLen(len(data)+header.TCPMinimumSize+12), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(true, 0, tsVal+1), - ), - ) -} - -// TestTimeStampEnabledAccept tests that if the SYN on a passive connect -// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK -// and echoes the tsVal field of the original SYN in the tcEcr field of the -// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify -// that Timestamp option is enabled in both cases if requested in the original -// SYN. -func TestTimeStampEnabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. - } - for _, tc := range testCases { - timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now send some data with the accepted connection endpoint and validate - // that no timestamp option is sent in the TCP segment. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) - } - - // Check that data is received and that the timestamp option is disabled - // when SYN cookies are enabled/disabled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(false, 0, 0), - ), - ) -} - -// TestTimeStampDisabledAccept tests that Timestamp option is not used when the -// peer doesn't advertise it and connection is established with Accept(). -func TestTimeStampDisabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. - } - for _, tc := range testCases { - timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func TestSendGreaterThanMTUWithOptions(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - createConnectedWithTimestampOption(c) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - droppedPacketsStat := c.Stack().Stats().DroppedPackets - droppedPackets := droppedPacketsStat.Value() - data := []byte{1, 2, 3} - // Send a packet with no TCP options/timestamp. - rep.SendPacket(data, nil) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Assert that DroppedPackets was not incremented. - if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { - t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) - } - - // Issue a read and we should data. - got, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if want := data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD deleted file mode 100644 index b33ec2087..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - testonly = 1, - srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go deleted file mode 100644 index b0a376eba..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ /dev/null @@ -1,1100 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provides a test context for use in tcp tests. It also -// provides helper methods to assert/check certain behaviours. -package context - -import ( - "bytes" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // StackAddr is the IPv4 address assigned to the stack. - StackAddr = "\x0a\x00\x00\x01" - - // StackPort is used as the listening port in tests for passive - // connects. - StackPort = 1234 - - // TestAddr is the source address for packets sent to the stack via the - // link layer endpoint. - TestAddr = "\x0a\x00\x00\x02" - - // TestPort is the TCP port used for packets sent to the stack - // via the link layer endpoint. - TestPort = 4096 - - // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - - // TestV6Addr is the source address for packets sent to the stack via - // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - // StackV4MappedAddr is StackAddr as a mapped v6 address. - StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr - - // TestV4MappedAddr is TestAddr as a mapped v6 address. - TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr - - // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - // testInitialSequenceNumber is the initial sequence number sent in packets that - // are sent in response to a SYN or in the initial SYN sent to the stack. - testInitialSequenceNumber = 789 -) - -// Headers is used to represent the TCP header fields when building a -// new packet. -type Headers struct { - // SrcPort holds the src port value to be used in the packet. - SrcPort uint16 - - // DstPort holds the destination port value to be used in the packet. - DstPort uint16 - - // SeqNum is the value of the sequence number field in the TCP header. - SeqNum seqnum.Value - - // AckNum represents the acknowledgement number field in the TCP header. - AckNum seqnum.Value - - // Flags are the TCP flags in the TCP header. - Flags int - - // RcvWnd is the window to be advertised in the ReceiveWindow field of - // the TCP header. - RcvWnd seqnum.Size - - // TCPOpts holds the options to be sent in the option field of the TCP - // header. - TCPOpts []byte -} - -// Context provides an initialized Network stack and a link layer endpoint -// for use in TCP tests. -type Context struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - // IRS holds the initial sequence number in the SYN sent by endpoint in - // case of an active connect or the sequence number sent by the endpoint - // in the SYN-ACK sent in response to a SYN when listening in passive - // mode. - IRS seqnum.Value - - // Port holds the port bound by EP below in case of an active connect or - // the listening port number in case of a passive connect. - Port uint16 - - // EP is the test endpoint in the stack owned by this context. This endpoint - // is used in various tests to either initiate an active connect or is used - // as a passive listening endpoint to accept inbound connections. - EP tcpip.Endpoint - - // Wq is the wait queue associated with EP and is used to block for events - // on EP. - WQ waiter.Queue - - // TimeStampEnabled is true if ep is connected with the timestamp option - // enabled. - TimeStampEnabled bool - - // WindowScale is the expected window scale in SYN packets sent by - // the stack. - WindowScale uint8 -} - -// New allocates and initializes a test context containing a new -// stack and a link-layer endpoint. -func New(t *testing.T, mtu uint32) *Context { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, - }) - - // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Some of the congestion control tests send up to 640 packets, we so - // set the channel size to 1000. - ep := channel.New(1000, mtu, "") - wep := stack.LinkEndpoint(ep) - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) - if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, mtu, "")) - } - if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &Context{ - t: t, - s: s, - linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), - } -} - -// Cleanup closes the context endpoint if required. -func (c *Context) Cleanup() { - if c.EP != nil { - c.EP.Close() - } -} - -// Stack returns a reference to the stack in the Context. -func (c *Context) Stack() *stack.Stack { - return c.s -} - -// CheckNoPacketTimeout verifies that no packet is received during the time -// specified by wait. -func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { - c.t.Helper() - - select { - case <-c.linkEP.C: - c.t.Fatal(errMsg) - - case <-time.After(wait): - } -} - -// CheckNoPacket verifies that no packet is received for 1 second. -func (c *Context) CheckNoPacket(errMsg string) { - c.CheckNoPacketTimeout(errMsg, 1*time.Second) -} - -// GetPacket reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { - c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -// GetPacketNonBlocking reads a packet from the link layer endpoint -// and verifies that it is an IPv4 packet with the expected source -// and destination address. If no packet is available it will return -// nil immediately. -func (c *Context) GetPacketNonBlocking() []byte { - c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - default: - return nil - } -} - -// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { - // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) - if len(buf) > maxTotalSize { - buf = buf[:maxTotalSize] - } - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) - icmp.SetType(typ) - icmp.SetCode(code) - const icmpv4VariableHeaderOffset = 4 - copy(icmp[icmpv4VariableHeaderOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset:], p2) - - // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -// BuildSegment builds a TCP segment based on the given Headers and payload. -func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { - return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) -} - -// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, -// payload and source and destination IPv4 addresses. -func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv4MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - return buf.ToVectorisedView() -} - -// SendSegment sends a TCP segment that has already been built and written to a -// buffer.VectorisedView. -func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: s, - }) -} - -// SendPacket builds and sends a TCP segment(with the provided payload & TCP -// headers) in an IPv4 packet via the link layer endpoint. -func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: c.BuildSegment(payload, h), - }) -} - -// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload -// & TCPheaders) in an IPv4 packet via the link layer endpoint using the -// provided source and destination IPv4 addresses. -func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: c.BuildSegmentWithAddrs(payload, h, src, dst), - }) -} - -// SendAck sends an ACK packet. -func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { - c.SendAckWithSACK(seq, bytesReceived, nil) -} - -// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified. -func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) { - options := make([]byte, 40) - offset := 0 - if len(sackBlocks) > 0 { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) - } - - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options[:offset], - }) -} - -// ReceiveAndCheckPacket reads a packet from the link layer endpoint and -// verifies that the packet packet payload of packet matches the slice -// of data indicated by offset & size. -func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { - c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) -} - -// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size and skips optlen bytes in addition to the IP -// TCP headers when comparing the data. -func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize+optlen), - checker.TCP( - checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } -} - -// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size. It returns true if a packet was received and -// processed. -func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { - b := c.GetPacketNonBlocking() - if b == nil { - return false - } - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } - return true -} - -// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only -// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 -// only endpoint instead of a default dual stack socket. -func (c *Context) CreateV6Endpoint(v6only bool) { - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.EP.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } -} - -// GetV6Packet reads a single packet from the link layer endpoint of the context -// and asserts that it is an IPv6 Packet with the expected src/dest addresses. -func (c *Context) GetV6Packet() []byte { - c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of -// the context. -func (c *Context) SendV6Packet(payload []byte, h *Headers) { - c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) -} - -// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer -// endpoint of the context using the provided source and destination IPv6 -// addresses. -func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, - }) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv6MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) -} - -// CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { - c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) -} - -// Connect performs the 3-way handshake for c.EP with the provided Initial -// Sequence Number (iss) and receive window(rcvWnd) and any options if -// specified. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -// -// PreCondition: c.EP must already be created. -func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { - c.t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - c.SendPacket(nil, &Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: options, - }) - - // Receive ACK packet. - checker.IPv4(c.t, c.GetPacket(), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.Port = tcpHdr.SourcePort() -} - -// Create creates a TCP endpoint. -func (c *Context) Create(epRcvBuf int) { - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if epRcvBuf != -1 { - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - } -} - -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { - c.Create(epRcvBuf) - c.Connect(iss, rcvWnd, options) -} - -// RawEndpoint is just a small wrapper around a TCP endpoint's state to make -// sending data and ACK packets easy while being able to manipulate the sequence -// numbers and timestamp values as needed. -type RawEndpoint struct { - C *Context - SrcPort uint16 - DstPort uint16 - Flags int - NextSeqNum seqnum.Value - AckNum seqnum.Value - WndSize seqnum.Size - RecentTS uint32 // Stores the latest timestamp to echo back. - TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. - - // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. - SACKPermitted bool -} - -// SendPacketWithTS embeds the provided tsVal in the Timestamp option -// for the packet to be sent out. -func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { - r.TSVal = tsVal - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) - r.SendPacket(payload, tsOpt[:]) -} - -// SendPacket is a small wrapper function to build and send packets. -func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { - packetHeaders := &Headers{ - SrcPort: r.SrcPort, - DstPort: r.DstPort, - Flags: r.Flags, - SeqNum: r.NextSeqNum, - AckNum: r.AckNum, - RcvWnd: r.WndSize, - TCPOpts: opts, - } - r.C.SendPacket(payload, packetHeaders) - r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) -} - -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { - // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.TCPTimestampChecker(true, 0, tsVal), - ), - ) - // Store the parsed TSVal from the ack as recentTS. - tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) - opts := tcpSeg.ParsedOptions() - r.RecentTS = opts.TSVal -} - -// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK -// matches the provided rcvWnd. -func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.Window(rcvWnd), - ), - ) -} - -// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. -func (r *RawEndpoint) VerifyACKNoSACK() { - r.VerifyACKHasSACK(nil) -} - -// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. -func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { - // Read ACK and verify that the TCP options in the segment do - // not contain a SACK block. - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.TCPSACKBlockChecker(sackBlocks), - ), - ) -} - -// CreateConnectedWithOptions creates and connects c.ep with the specified TCP -// options enabled and returns a RawEndpoint which represents the other end of -// the connection. -// -// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK -// does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} - err = c.EP.Connect(testFullAddr) - if err != tcpip.ErrConnectStarted { - c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) - } - // Receive SYN packet. - b := c.GetPacket() - // Validate that the syn has the timestamp option and a valid - // TS value. - mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: int(c.WindowScale), - SACKPermitted: c.SACKEnabled(), - }), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpSeg := header.TCP(header.IPv4(b).Payload()) - synOptions := header.ParseSynOptions(tcpSeg.Options(), false) - - // Build options w/ tsVal to be sent in the SYN-ACK. - synAckOptions := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - if wantOptions.WS != -1 { - offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:]) - } - if wantOptions.TS { - offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) - } - if wantOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) - } - - offset += header.AddTCPOptionPadding(synAckOptions, offset) - - // Build SYN-ACK. - c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(testInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - TCPOpts: synAckOptions[:offset], - }) - - // Read ACK. - ackPacket := c.GetPacket() - - // Verify TCP header fields. - tcpCheckers := []checker.TransportChecker{ - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), - } - - // Verify that tsEcr of ACK packet is wantOptions.TSVal if the - // timestamp option was enabled, if not then we verify that - // there is no timestamp in the ACK packet. - if wantOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) - - ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) - ackOptions := ackSeg.ParsedOptions() - - // Wait for connection to be established. - select { - case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Store the source port in use by the endpoint. - c.Port = tcpSeg.SourcePort() - - // Mark in context that timestamp option is enabled for this endpoint. - c.TimeStampEnabled = true - - return &RawEndpoint{ - C: c, - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagAck | header.TCPFlagPsh, - NextSeqNum: iss + 1, - AckNum: c.IRS.Add(1), - WndSize: 30000, - RecentTS: ackOptions.TSVal, - TSVal: wantOptions.TSVal, - SACKPermitted: wantOptions.SACKPermitted, - } -} - -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. -// -// The function returns a RawEndpoint representing the other end of the accepted -// endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if err := ep.Listen(10); err != nil { - c.t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - c.t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - return rep -} - -// PassiveConnect just disables WindowScaling and delegates the call to -// PassiveConnectWithOptions. -func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { - synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) -} - -// PassiveConnectWithOptions initiates a new connection (with the specified TCP -// options enabled) to the port on which the Context.ep is listening for new -// connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. -// -// NOTE: MSS is not a negotiated option and it can be asymmetric -// in each direction. This function uses the maxPayload to set the MSS to be -// sent to the peer on a connect and validates that the MSS in the SYN-ACK -// response is equal to the MTU - (tcphdr len + iphdr len). -// -// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the -// value of the window scaling option to be sent in the SYN. If synOptions.WS > -// 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - opts := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - offset += header.EncodeMSSOption(uint32(maxPayload), opts) - - if synOptions.WS >= 0 { - offset += header.EncodeWSOption(3, opts[offset:]) - } - if synOptions.TS { - offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) - } - - if synOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(opts[offset:]) - } - - paddingToAdd := 4 - offset%4 - // Now add any padding bytes that might be required to quad align the - // options. - for i := offset; i < offset+paddingToAdd; i++ { - opts[i] = header.TCPOptionNOP - } - offset += paddingToAdd - - // Send a SYN request. - iss := seqnum.Value(testInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - TCPOpts: opts[:offset], - }) - - // Receive the SYN-ACK reply. Make sure MSS and other expected options - // are present. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(StackPort), - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), - checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), - } - - // If TS option was enabled in the original SYN then add a checker to - // validate the Timestamp option in the SYN-ACK. - if synOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) - rcvWnd := seqnum.Size(30000) - ackHeaders := &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: rcvWnd, - } - - // If WS was expected to be in effect then scale the advertised window - // correspondingly. - if synOptions.WS > 0 { - ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) - } - - parsedOpts := tcp.ParsedOptions() - if synOptions.TS { - // Echo the tsVal back to the peer in the tsEcr field of the - // timestamp option. - // Increment TSVal by 1 from the value sent in the SYN and echo - // the TSVal in the SYN-ACK in the TSEcr field. - opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) - ackHeaders.TCPOpts = opts[:] - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.Port = StackPort - - return &RawEndpoint{ - C: c, - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagPsh | header.TCPFlagAck, - NextSeqNum: iss + 1, - AckNum: c.IRS + 1, - WndSize: rcvWnd, - SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), - RecentTS: parsedOpts.TSVal, - TSVal: synOptions.TSVal + 1, - } -} - -// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true -// for the Stack in the context. -func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { - // Stack doesn't support SACK. So just return. - return false - } - return bool(v) -} - -// SetGSOEnabled enables or disables generic segmentation offload. -func (c *Context) SetGSOEnabled(enable bool) { - c.linkEP.GSO = enable -} - -// MSSWithoutOptions returns the value for the MSS used by the stack when no -// options are in use. -func (c *Context) MSSWithoutOptions() uint16 { - return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) -} - -// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no -// options are in use for IPv6 packets. -func (c *Context) MSSWithoutOptionsV6() uint16 { - return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD deleted file mode 100644 index 43fcc27f0..000000000 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tcpconntrack", - srcs = ["tcp_conntrack.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcpconntrack_test", - size = "small", - srcs = ["tcp_conntrack_test.go"], - deps = [ - ":tcpconntrack", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go deleted file mode 100644 index 93712cd45..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tcpconntrack implements a TCP connection tracking object. It allows -// users with access to a segment stream to figure out when a connection is -// established, reset, and closed (and in the last case, who closed first). -package tcpconntrack - -import ( - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// Result is returned when the state of a TCB is updated in response to an -// inbound or outbound segment. -type Result int - -const ( - // ResultDrop indicates that the segment should be dropped. - ResultDrop Result = iota - - // ResultConnecting indicates that the connection remains in a - // connecting state. - ResultConnecting - - // ResultAlive indicates that the connection remains alive (connected). - ResultAlive - - // ResultReset indicates that the connection was reset. - ResultReset - - // ResultClosedByPeer indicates that the connection was gracefully - // closed, and the inbound stream was closed first. - ResultClosedByPeer - - // ResultClosedBySelf indicates that the connection was gracefully - // closed, and the outbound stream was closed first. - ResultClosedBySelf -) - -// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP -// connection and inform the caller when the connection has been closed. -type TCB struct { - inbound stream - outbound stream - - // State handlers. - handlerInbound func(*TCB, header.TCP) Result - handlerOutbound func(*TCB, header.TCP) Result - - // firstFin holds a pointer to the first stream to send a FIN. - firstFin *stream - - // state is the current state of the stream. - state Result -} - -// Init initializes the state of the TCB according to the initial SYN. -func (t *TCB) Init(initialSyn header.TCP) Result { - t.handlerInbound = synSentStateInbound - t.handlerOutbound = synSentStateOutbound - - iss := seqnum.Value(initialSyn.SequenceNumber()) - t.outbound.una = iss - t.outbound.nxt = iss.Add(logicalLen(initialSyn)) - t.outbound.end = t.outbound.nxt - - // Even though "end" is a sequence number, we don't know the initial - // receive sequence number yet, so we store the window size until we get - // a SYN from the peer. - t.inbound.una = 0 - t.inbound.nxt = 0 - t.inbound.end = seqnum.Value(initialSyn.WindowSize()) - t.state = ResultConnecting - return t.state -} - -// UpdateStateInbound updates the state of the TCB based on the supplied inbound -// segment. -func (t *TCB) UpdateStateInbound(tcp header.TCP) Result { - st := t.handlerInbound(t, tcp) - if st != ResultDrop { - t.state = st - } - return st -} - -// UpdateStateOutbound updates the state of the TCB based on the supplied -// outbound segment. -func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { - st := t.handlerOutbound(t, tcp) - if st != ResultDrop { - t.state = st - } - return st -} - -// IsAlive returns true as long as the connection is established(Alive) -// or connecting state. -func (t *TCB) IsAlive() bool { - return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed()) -} - -// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream. -func (t *TCB) OutboundSendSequenceNumber() seqnum.Value { - return t.outbound.nxt -} - -// InboundSendSequenceNumber returns the snd.NXT for the inbound stream. -func (t *TCB) InboundSendSequenceNumber() seqnum.Value { - return t.inbound.nxt -} - -// adapResult modifies the supplied "Result" according to the state of the TCB; -// if r is anything other than "Alive", or if one of the streams isn't closed -// yet, it is returned unmodified. Otherwise it's converted to either -// ClosedBySelf or ClosedByPeer depending on which stream was closed first. -func (t *TCB) adaptResult(r Result) Result { - // Check the unmodified case. - if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() { - return r - } - - // Find out which was closed first. - if t.firstFin == &t.outbound { - return ResultClosedBySelf - } - - return ResultClosedByPeer -} - -// synSentStateInbound is the state handler for inbound segments when the -// connection is in SYN-SENT state. -func synSentStateInbound(t *TCB, tcp header.TCP) Result { - flags := tcp.Flags() - ackPresent := flags&header.TCPFlagAck != 0 - ack := seqnum.Value(tcp.AckNumber()) - - // Ignore segment if ack is present but not acceptable. - if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) { - return ResultConnecting - } - - // If reset is specified, we will let the packet through no matter what - // but we will also destroy the connection if the ACK is present (and - // implicitly acceptable). - if flags&header.TCPFlagRst != 0 { - if ackPresent { - t.inbound.rstSeen = true - return ResultReset - } - return ResultConnecting - } - - // Ignore segment if SYN is not set. - if flags&header.TCPFlagSyn == 0 { - return ResultConnecting - } - - // Update state informed by this SYN. - irs := seqnum.Value(tcp.SequenceNumber()) - t.inbound.una = irs - t.inbound.nxt = irs.Add(logicalLen(tcp)) - t.inbound.end += irs - - t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize())) - - // If the ACK was set (it is acceptable), update our unacknowledgement - // tracking. - if ackPresent { - // Advance the "una" and "end" indices of the outbound stream. - if t.outbound.una.LessThan(ack) { - t.outbound.una = ack - } - - if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) { - t.outbound.end = end - } - } - - // Update handlers so that new calls will be handled by new state. - t.handlerInbound = allOtherInbound - t.handlerOutbound = allOtherOutbound - - return ResultAlive -} - -// synSentStateOutbound is the state handler for outbound segments when the -// connection is in SYN-SENT state. -func synSentStateOutbound(t *TCB, tcp header.TCP) Result { - // Drop outbound segments that aren't retransmits of the original one. - if tcp.Flags() != header.TCPFlagSyn || - tcp.SequenceNumber() != uint32(t.outbound.una) { - return ResultDrop - } - - // Update the receive window. We only remember the largest value seen. - if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end { - t.inbound.end = wnd - } - - return ResultConnecting -} - -// update updates the state of inbound and outbound streams, given the supplied -// inbound segment. For outbound segments, this same function can be called with -// swapped inbound/outbound streams. -func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result { - // Ignore segments out of the window. - s := seqnum.Value(tcp.SequenceNumber()) - if !inbound.acceptable(s, dataLen(tcp)) { - return ResultAlive - } - - flags := tcp.Flags() - if flags&header.TCPFlagRst != 0 { - inbound.rstSeen = true - return ResultReset - } - - // Ignore segments that don't have the ACK flag, and those with the SYN - // flag. - if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 { - return ResultAlive - } - - // Ignore segments that acknowledge not yet sent data. - ack := seqnum.Value(tcp.AckNumber()) - if outbound.nxt.LessThan(ack) { - return ResultAlive - } - - // Advance the "una" and "end" indices of the outbound stream. - if outbound.una.LessThan(ack) { - outbound.una = ack - } - - if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) { - outbound.end = end - } - - // Advance the "nxt" index of the inbound stream. - end := s.Add(logicalLen(tcp)) - if inbound.nxt.LessThan(end) { - inbound.nxt = end - } - - // Note the index of the FIN segment. And stash away a pointer to the - // first stream to see a FIN. - if flags&header.TCPFlagFin != 0 && !inbound.finSeen { - inbound.finSeen = true - inbound.fin = end - 1 - - if *firstFin == nil { - *firstFin = inbound - } - } - - return ResultAlive -} - -// allOtherInbound is the state handler for inbound segments in all states -// except SYN-SENT. -func allOtherInbound(t *TCB, tcp header.TCP) Result { - return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin)) -} - -// allOtherOutbound is the state handler for outbound segments in all states -// except SYN-SENT. -func allOtherOutbound(t *TCB, tcp header.TCP) Result { - return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin)) -} - -// streams holds the state of a TCP unidirectional stream. -type stream struct { - // The interval [una, end) is the allowed interval as defined by the - // receiver, i.e., anything less than una has already been acknowledged - // and anything greater than or equal to end is beyond the receiver - // window. The interval [una, nxt) is the acknowledgable range, whose - // right edge indicates the sequence number of the next byte to be sent - // by the sender, i.e., anything greater than or equal to nxt hasn't - // been sent yet. - una seqnum.Value - nxt seqnum.Value - end seqnum.Value - - // finSeen indicates if a FIN has already been sent on this stream. - finSeen bool - - // fin is the sequence number of the FIN. It is only valid after finSeen - // is set to true. - fin seqnum.Value - - // rstSeen indicates if a RST has already been sent on this stream. - rstSeen bool -} - -// acceptable determines if the segment with the given sequence number and data -// length is acceptable, i.e., if it's within the [una, end) window or, in case -// the window is zero, if it's a packet with no payload and sequence number -// equal to una. -func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - wnd := s.una.Size(s.end) - if wnd == 0 { - return segLen == 0 && segSeq == s.una - } - - // Make sure [segSeq, seqSeq+segLen) is non-empty. - if segLen == 0 { - segLen = 1 - } - - return seqnum.Overlap(s.una, wnd, segSeq, segLen) -} - -// closed determines if the stream has already been closed. This happens when -// a FIN has been set by the sender and acknowledged by the receiver. -func (s *stream) closed() bool { - return s.finSeen && s.fin.LessThan(s.una) -} - -// dataLen returns the length of the TCP segment payload. -func dataLen(tcp header.TCP) seqnum.Size { - return seqnum.Size(len(tcp) - int(tcp.DataOffset())) -} - -// logicalLen calculates the logical length of the TCP segment. -func logicalLen(tcp header.TCP) seqnum.Size { - l := dataLen(tcp) - flags := tcp.Flags() - if flags&header.TCPFlagSyn != 0 { - l++ - } - if flags&header.TCPFlagFin != 0 { - l++ - } - return l -} diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go deleted file mode 100644 index 5e271b7ca..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpconntrack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" -) - -// connected creates a connection tracker TCB and sets it to a connected state -// by performing a 3-way handshake. -func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: iss, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: irw, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: irs, - AckNum: iss + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: isw, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: iss + 1, - AckNum: irs + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: irw, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - return &tcb -} - -func TestConnectionRefused(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive RST. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionRefusedInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionResetInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestRetransmitOnSynSent(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Retransmit the same SYN. - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) - } -} - -func TestRetransmitOnSynRcvd(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. This will cause the state to go to SYN-RCVD. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Retransmit the original SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Transmit a SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} - -func TestClosedBySelf(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Send FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1236, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestClosedByPeer(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Receive FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 791, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) - } -} - -func TestSendAndReceiveDataClosedBySelf(t *testing.T) { - sseq := uint32(1234) - rseq := uint32(789) - tcb := connected(t, sseq, rseq, 30000, 50000) - sseq++ - rseq++ - - // Send some data. - tcp := make(header.TCP, header.TCPMinimumSize+1024) - - for i := uint32(0); i < 10; i++ { - // Send some data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - sseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - for i := uint32(0); i < 10; i++ { - // Receive some data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - rseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - // Send FIN. - tcp = tcp[:header.TCPMinimumSize] - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - sseq++ - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - rseq++ - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestIgnoreBadResetOnSynSent(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive a RST with a bad ACK, it should not cause the connection to - // be reset. - acks := []uint32{1234, 1236, 1000, 5000} - flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} - for _, a := range acks { - for _, f := range flags { - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: a, - DataOffset: header.TCPMinimumSize, - Flags: f, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - } - - // Complete the handshake. - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD deleted file mode 100644 index 97e4d5825..000000000 --- a/pkg/tcpip/transport/udp/BUILD +++ /dev/null @@ -1,62 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_template_instance( - name = "udp_packet_list", - out = "udp_packet_list.go", - package = "udp", - prefix = "udpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*udpPacket", - "Linker": "*udpPacket", - }, -) - -go_library( - name = "udp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "udp_packet_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - ], -) - -go_test( - name = "udp_x_test", - size = "small", - srcs = ["udp_test.go"], - deps = [ - ":udp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100755 index 000000000..673a9373b --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,173 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *udpPacketList) PushFront(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + udpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *udpPacketList) PushBack(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(nil) + udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + a := udpPacketElementMapper{}.linkerFor(b).Next() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + b := udpPacketElementMapper{}.linkerFor(a).Prev() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *udpPacketList) Remove(e *udpPacket) { + prev := udpPacketElementMapper{}.linkerFor(e).Prev() + next := udpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100755 index 000000000..5c4b43214 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,134 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *udpPacket) beforeSave() {} +func (x *udpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("udpPacketEntry", &x.udpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) +} + +func (x *udpPacket) afterLoad() {} +func (x *udpPacket) load(m state.Map) { + m.Load("udpPacketEntry", &x.udpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("uniqueID", &x.uniqueID) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("state", &x.state) + m.Save("dstPort", &x.dstPort) + m.Save("v6only", &x.v6only) + m.Save("ttl", &x.ttl) + m.Save("multicastTTL", &x.multicastTTL) + m.Save("multicastAddr", &x.multicastAddr) + m.Save("multicastNICID", &x.multicastNICID) + m.Save("multicastLoop", &x.multicastLoop) + m.Save("reusePort", &x.reusePort) + m.Save("bindToDevice", &x.bindToDevice) + m.Save("broadcast", &x.broadcast) + m.Save("boundBindToDevice", &x.boundBindToDevice) + m.Save("boundPortFlags", &x.boundPortFlags) + m.Save("sendTOS", &x.sendTOS) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("multicastMemberships", &x.multicastMemberships) + m.Save("effectiveNetProtos", &x.effectiveNetProtos) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("uniqueID", &x.uniqueID) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("state", &x.state) + m.Load("dstPort", &x.dstPort) + m.Load("v6only", &x.v6only) + m.Load("ttl", &x.ttl) + m.Load("multicastTTL", &x.multicastTTL) + m.Load("multicastAddr", &x.multicastAddr) + m.Load("multicastNICID", &x.multicastNICID) + m.Load("multicastLoop", &x.multicastLoop) + m.Load("reusePort", &x.reusePort) + m.Load("bindToDevice", &x.bindToDevice) + m.Load("broadcast", &x.broadcast) + m.Load("boundBindToDevice", &x.boundBindToDevice) + m.Load("boundPortFlags", &x.boundPortFlags) + m.Load("sendTOS", &x.sendTOS) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("multicastMemberships", &x.multicastMemberships) + m.Load("effectiveNetProtos", &x.effectiveNetProtos) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *multicastMembership) beforeSave() {} +func (x *multicastMembership) save(m state.Map) { + x.beforeSave() + m.Save("nicID", &x.nicID) + m.Save("multicastAddr", &x.multicastAddr) +} + +func (x *multicastMembership) afterLoad() {} +func (x *multicastMembership) load(m state.Map) { + m.Load("nicID", &x.nicID) + m.Load("multicastAddr", &x.multicastAddr) +} + +func (x *udpPacketList) beforeSave() {} +func (x *udpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *udpPacketList) afterLoad() {} +func (x *udpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *udpPacketEntry) beforeSave() {} +func (x *udpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *udpPacketEntry) afterLoad() {} +func (x *udpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load}) + state.Register("udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load}) + state.Register("udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load}) + state.Register("udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go deleted file mode 100644 index 7051a7a9c..000000000 --- a/pkg/tcpip/transport/udp/udp_test.go +++ /dev/null @@ -1,1677 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_test - -import ( - "bytes" - "fmt" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// Addresses and ports used for testing. It is recommended that tests stick to -// using these addresses as it allows using the testFlow helper. -// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' -// represents the remote endpoint. -const ( - v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = v4MappedAddrPrefix + stackAddr - testV4MappedAddr = v4MappedAddrPrefix + testAddr - multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr - broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr - v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 - multicastAddr = "\xe8\x2b\xd3\xea" - multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - broadcastAddr = header.IPv4Broadcast - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in -// a packet header. These values are used to populate a header or verify one. -// Note that because they are used in packet headers, the addresses are never in -// a V4-mapped format. -type header4Tuple struct { - srcAddr tcpip.FullAddress - dstAddr tcpip.FullAddress -} - -// testFlow implements a helper type used for sending and receiving test -// packets. A given test flow value defines 1) the socket endpoint used for the -// test and 2) the type of packet send or received on the endpoint. E.g., a -// multicastV6Only flow is a V6 multicast packet passing through a V6-only -// endpoint. The type provides helper methods to characterize the flow (e.g., -// isV4) as well as return a proper header4Tuple for it. -type testFlow int - -const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket -) - -func (flow testFlow) String() string { - switch flow { - case unicastV4: - return "unicastV4" - case unicastV6: - return "unicastV6" - case unicastV6Only: - return "unicastV6Only" - case unicastV4in6: - return "unicastV4in6" - case multicastV4: - return "multicastV4" - case multicastV6: - return "multicastV6" - case multicastV6Only: - return "multicastV6Only" - case multicastV4in6: - return "multicastV4in6" - case broadcast: - return "broadcast" - case broadcastIn6: - return "broadcastIn6" - default: - return "unknown" - } -} - -// packetDirection explains if a flow is incoming (read) or outgoing (write). -type packetDirection int - -const ( - incoming packetDirection = iota - outgoing -) - -// header4Tuple returns the header4Tuple for the given flow and direction. Note -// that the tuple contains no mapped addresses as those only exist at the socket -// level but not at the packet header level. -func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { - var h header4Tuple - if flow.isV4() { - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastAddr - } else if flow.isBroadcast() { - h.dstAddr.Addr = broadcastAddr - } - } else { // IPv6 - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastV6Addr - } - } - return h -} - -func (flow testFlow) getMcastAddr() tcpip.Address { - if flow.isV4() { - return multicastAddr - } - return multicastV6Addr -} - -// mapAddrIfApplicable converts the given V4 address into its V4-mapped version -// if it is applicable to the flow. -func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { - if flow.isMapped() { - return v4MappedAddrPrefix + v4Addr - } - return v4Addr -} - -// netProto returns the protocol number used for the network packet. -func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { - if flow.isV4() { - return ipv4.ProtocolNumber - } - return ipv6.ProtocolNumber -} - -// sockProto returns the protocol number used when creating the socket -// endpoint for this flow. -func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { - switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: - return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: - return ipv4.ProtocolNumber - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { - if flow.isV4() { - return checker.IPv4 - } - return checker.IPv6 -} - -func (flow testFlow) isV6() bool { return !flow.isV4() } -func (flow testFlow) isV4() bool { - return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() -} - -func (flow testFlow) isV6Only() bool { - switch flow { - case unicastV6Only, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMulticast() bool { - switch flow { - case multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isBroadcast() bool { - switch flow { - case broadcast, broadcastIn6: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMapped() bool { - switch flow { - case unicastV4in6, multicastV4in6, broadcastIn6: - return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func newDualTestContext(t *testing.T, mtu uint32) *testContext { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - }) - ep := channel.New(256, mtu, "") - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { - c.t.Helper() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) - if err != nil { - c.t.Fatal("NewEndpoint failed: ", err) - } -} - -func (c *testContext) createEndpointForFlow(flow testFlow) { - c.t.Helper() - - c.createEndpoint(flow.sockProto()) - if flow.isV6Only() { - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) - } - } -} - -// getPacketAndVerify reads a packet from the link endpoint and verifies the -// header against expected values from the given test flow. In addition, it -// calls any extra checker functions provided. -func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { - c.t.Helper() - - select { - case p := <-c.linkEP.C: - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - h := flow.header4Tuple(outgoing) - checkers := append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -// injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. -func (c *testContext) injectPacket(flow testFlow, payload []byte) { - c.t.Helper() - - h := flow.header4Tuple(incoming) - if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) - } else { - c.injectV6Packet(payload, &h, true /* valid */) - } -} - -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: l, - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), - }) -} - -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), - }) -} - -func newPayload() []byte { - return newMinPayload(30) -} - -func newMinPayload(minSize int) []byte { - b := make([]byte, minSize+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // strPtr is used instead of taking the address of string literals, which is - // a compiler error. - strPtr := func(s string) *string { - return &s - } - - testActions := []struct { - name string - setBindToDevice *string - setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption - }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) - } - } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) - } - }) - } -} - -// testReadInternal sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then attempts to read it from the -// UDP endpoint and depending on if this was expected to succeed verifies its -// correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { - c.t.Helper() - - payload := newPayload() - c.injectPacket(flow, payload) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - - case <-time.After(300 * time.Millisecond): - if packetShouldBeDropped { - return // expected to time out - } - c.t.Fatal("timed out waiting for data") - } - } - - if expectReadError && err != nil { - c.checkEndpointReadStats(1, epstats, err) - return - } - - if err != nil { - c.t.Fatal("Read failed:", err) - } - - if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) - } - - // Check the peer address. - h := flow.header4Tuple(incoming) - if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("bad payload: got %x, want %x", v, payload) - } - c.checkEndpointReadStats(1, epstats, err) -} - -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { - c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) -} - -// testFailingRead sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then tries to read it from the UDP -// endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { - c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) -} - -func TestBindEphemeralPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } -} - -func TestBindReservedPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - addr, err := c.ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) - } - - // We can't bind the address reserved by the connected endpoint above. - { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - } - - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - // We can't bind ipv4-any on the port reserved by the connected endpoint - // above, since the endpoint is dual-stack. - if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - // We can bind an ipv4 address on this port, though. - if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() - - // Once the connected endpoint releases its port reservation, we are able to - // bind ipv4-any once again. - c.ep.Close() - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() -} - -func TestV4ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to local address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV6ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV6) -} - -func TestV4ReadOnV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4) -} - -// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast -// address and receive data sent to that address. -func TestReadOnBoundToMulticast(t *testing.T) { - // FIXME(b/128189410): multicastV4in6 currently doesn't work as - // AddMembershipOption doesn't handle V4in6 addresses. - for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to multicast address. - mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) - if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - // Join multicast group. - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatal("SetSockOpt failed:", err) - } - - // Check that we receive multicast packets but not unicast or broadcast - // ones. - testRead(c, flow) - testFailingRead(c, broadcast, false /* expectReadError */) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and can receive only broadcast data. -func TestV4ReadOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to broadcast address. - bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) - if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Check that we receive broadcast packets but not unicast ones. - testRead(c, flow) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY -// and receive broadcast and unicast data. -func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s (", err) - } - - // Check that we receive both broadcast and unicast packets. - testRead(c, flow) - testRead(c, unicastV4) - }) - } -} - -// testFailingWrite sends a packet of the given test flow into the UDP endpoint -// and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - - payload := buffer.View(newPayload()) - _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - }) - c.checkEndpointWriteStats(1, epstats, gotErr) - if gotErr != wantErr { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) - } -} - -// testWrite sends a packet of the given test flow from the UDP endpoint to the -// flow's destination address:port. It then receives it from the link endpoint -// and verifies its correctness including any additional checker functions -// provided. -func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteInternal(c, flow, true, checkers...) -} - -// testWriteWithoutDestination sends a packet of the given test flow from the -// UDP endpoint without giving a destination address:port. It then receives it -// from the link endpoint and verifies its correctness including any additional -// checker functions provided. -func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteInternal(c, flow, false, checkers...) -} - -func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - writeOpts := tcpip.WriteOptions{} - if setDest { - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - writeOpts = tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - } - } - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - c.checkEndpointWriteStats(1, epstats, err) - // Received the packet and check the payload. - b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP - if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) - } else { - udp = header.UDP(header.IPv6(b).Payload()) - } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } - - return udp.SourcePort() -} - -func testDualWrite(c *testContext) uint16 { - c.t.Helper() - - v4Port := testWrite(c, unicastV4in6) - v6Port := testWrite(c, unicastV6) - if v4Port != v6Port { - c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) - } - - return v4Port -} - -func TestDualWriteUnbound(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) -} - -func TestDualWriteBoundToWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - p := testDualWrite(c) - if p != stackPort { - c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) - } -} - -func TestDualWriteConnectedToV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testWrite(c, unicastV6) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) - const want = 1 - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { - c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) - } -} - -func TestDualWriteConnectedToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testWrite(c, unicastV4in6) - - // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) -} - -func TestV4WriteOnV6Only(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6Only) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) -} - -func TestV6WriteOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to v4 mapped address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) -} - -func TestV6WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - testWriteWithoutDestination(c, unicastV6) -} - -func TestV4WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - testWriteWithoutDestination(c, unicastV4) -} - -// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket -// that is bound to a V4 multicast address. -func TestWriteOnBoundToV4Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a -// socket that is bound to a V4-mapped multicast address. -func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV6, multicastV6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// V6-only socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToBroadcast checks that we can send packets out of a -// socket that is bound to the broadcast address. -func TestWriteOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 broadcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a -// socket that is bound to the V4-mapped broadcast address. -func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -func TestReadIncrementsPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create IPv4 UDP endpoint - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testRead(c, unicastV4) - - var want uint64 = 1 - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { - c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) - } -} - -func TestWriteIncrementsPacketsSent(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) - - var want uint64 = 2 - if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { - c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) - } -} - -func TestTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - var wantTTL uint8 - if flow.isMulticast() { - wantTTL = multicastTTL - } else { - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } -} - -func TestSetTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - ep.Close() - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } - }) - } -} - -func TestTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = 0xC0 - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) - } - - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestTOSV6(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = 0xC0 - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) - } - - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - c.t.Errorf("SetSockOpt failed: %s", err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestMulticastInterfaceOption(t *testing.T) { - for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - h := flow.header4Tuple(outgoing) - mcastAddr := h.dstAddr.Addr - localIfAddr := h.srcAddr.Addr - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: flow.mapAddrIfApplicable(mcastAddr), - Port: stackPort, - } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - } - - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) - } -} - -// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV4UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true, will result in a payload large enough - // so that the final generated IPv4 packet is larger than - // header.IPv4MinimumProcessableDatagramSize. - largePayload bool - }{ - {unicastV4, true, false}, - {unicastV4, true, true}, - {multicastV4, false, false}, - {multicastV4, false, true}, - {broadcast, false, false}, - {broadcast, false, true}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(576) - } - c.injectPacket(tc.flow, payload) - if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: - t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return - } - } - - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } - - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") - } - }) - } -} - -// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV6UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true will result in a payload large enough to - // create an IPv6 packet > header.IPv6MinimumMTU bytes. - largePayload bool - }{ - {unicastV6, true, false}, - {unicastV6, true, true}, - {multicastV6, false, false}, - {multicastV6, false, true}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(1280) - } - c.injectPacket(tc.flow, payload) - if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: - t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return - } - } - - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") - } - }) - } -} - -// TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. -func TestIncrementMalformedPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - payload := newPayload() - c.t.Helper() - h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) - } -} - -// TestShutdownRead verifies endpoint read shutdown and error -// stats increment on packet receive. -func TestShutdownRead(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - testFailingRead(c, unicastV6, true /* expectReadError */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { - t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) - } -} - -// TestShutdownWrite verifies endpoint write shutdown and error -// stats increment on packet write. -func TestShutdownWrite(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) -} - -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil: - want.PacketsSent.IncrementBy(incr) - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: - want.WriteErrors.InvalidArgs.IncrementBy(incr) - case tcpip.ErrClosedForSend: - want.WriteErrors.WriteClosed.IncrementBy(incr) - case tcpip.ErrInvalidEndpointState: - want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoLinkAddress: - want.SendErrors.NoLinkAddr.IncrementBy(incr) - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: - want.SendErrors.NoRoute.IncrementBy(incr) - default: - want.SendErrors.SendToNetworkFailed.IncrementBy(incr) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil, tcpip.ErrWouldBlock: - case tcpip.ErrClosedForReceive: - want.ReadErrors.ReadClosed.IncrementBy(incr) - default: - c.t.Errorf("Endpoint error missing stats update err %v", err) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} |