diff options
Diffstat (limited to 'pkg/tcpip')
192 files changed, 45105 insertions, 6803 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 3c2b2b5ea..454e07662 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,12 +7,12 @@ go_library( srcs = [ "tcpip.go", "time_unsafe.go", + "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip/buffer", - "//pkg/tcpip/iptables", "//pkg/waiter", ], ) @@ -22,5 +21,12 @@ go_test( name = "tcpip_test", size = "small", srcs = ["tcpip_test.go"], - embed = [":tcpip"], + library = ":tcpip", +) + +go_test( + name = "tcpip_x_test", + size = "small", + srcs = ["timer_test.go"], + deps = [":tcpip"], ) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 78df5a0b1..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,14 +1,13 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "gonet", srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", @@ -22,7 +21,7 @@ go_test( name = "gonet_test", size = "small", srcs = ["gonet_test.go"], - embed = [":gonet"], + library = ":gonet", deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index cd6ce930a..d82ed5205 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -20,9 +20,9 @@ import ( "errors" "io" "net" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" } func (e *timeoutError) Timeout() bool { return true } func (e *timeoutError) Temporary() bool { return true } -// A Listener is a wrapper around a tcpip endpoint that implements +// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements // net.Listener. -type Listener struct { +type TCPListener struct { stack *stack.Stack ep tcpip.Endpoint wq *waiter.Queue cancel chan struct{} } -// NewListener creates a new Listener. -func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { - // Create TCP endpoint, bind it, then start listening. +// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint. +func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener { + return &TCPListener{ + stack: s, + ep: ep, + wq: wq, + cancel: make(chan struct{}), + } +} + +// ListenTCP creates a new TCPListener. +func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) { + // Create a TCP endpoint, bind it, then start listening. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) if err != nil { @@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr } } - return &Listener{ - stack: s, - ep: ep, - wq: &wq, - cancel: make(chan struct{}), - }, nil + return NewTCPListener(s, &wq, ep), nil } // Close implements net.Listener.Close. -func (l *Listener) Close() error { +func (l *TCPListener) Close() error { l.ep.Close() return nil } // Shutdown stops the HTTP server. -func (l *Listener) Shutdown() { +func (l *TCPListener) Shutdown() { l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) close(l.cancel) // broadcast cancellation } // Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { +func (l *TCPListener) Addr() net.Addr { a, err := l.ep.GetLocalAddress() if err != nil { return nil @@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error { return nil } -// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn +// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn // interface. -type Conn struct { +type TCPConn struct { deadlineTimer wq *waiter.Queue @@ -228,9 +233,9 @@ type Conn struct { read buffer.View } -// NewConn creates a new Conn. -func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { - c := &Conn{ +// NewTCPConn creates a new TCPConn. +func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { + c := &TCPConn{ wq: wq, ep: ep, } @@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { } // Accept implements net.Conn.Accept. -func (l *Listener) Accept() (net.Conn, error) { +func (l *TCPListener) Accept() (net.Conn, error) { n, wq, err := l.ep.Accept() if err == tcpip.ErrWouldBlock { @@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) { } } - return NewConn(wq, n), nil + return NewTCPConn(wq, n), nil } type opErrorer interface { @@ -323,13 +328,18 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a } // Read implements net.Conn.Read. -func (c *Conn) Read(b []byte) (int, error) { +func (c *TCPConn) Read(b []byte) (int, error) { c.readMu.Lock() defer c.readMu.Unlock() deadline := c.readCancel() numRead := 0 + defer func() { + if numRead != 0 { + c.ep.ModerateRecvBuf(numRead) + } + }() for numRead != len(b) { if len(c.read) == 0 { var err error @@ -352,7 +362,7 @@ func (c *Conn) Read(b []byte) (int, error) { } // Write implements net.Conn.Write. -func (c *Conn) Write(b []byte) (int, error) { +func (c *TCPConn) Write(b []byte) (int, error) { deadline := c.writeCancel() // Check if deadlineTimer has already expired. @@ -431,7 +441,7 @@ func (c *Conn) Write(b []byte) (int, error) { } // Close implements net.Conn.Close. -func (c *Conn) Close() error { +func (c *TCPConn) Close() error { c.ep.Close() return nil } @@ -440,7 +450,7 @@ func (c *Conn) Close() error { // should just use Close. // // A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. -func (c *Conn) CloseRead() error { +func (c *TCPConn) CloseRead() error { if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { return c.newOpError("close", errors.New(terr.String())) } @@ -451,7 +461,7 @@ func (c *Conn) CloseRead() error { // should just use Close. // // A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. -func (c *Conn) CloseWrite() error { +func (c *TCPConn) CloseWrite() error { if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { return c.newOpError("close", errors.New(terr.String())) } @@ -459,7 +469,7 @@ func (c *Conn) CloseWrite() error { } // LocalAddr implements net.Conn.LocalAddr. -func (c *Conn) LocalAddr() net.Addr { +func (c *TCPConn) LocalAddr() net.Addr { a, err := c.ep.GetLocalAddress() if err != nil { return nil @@ -468,7 +478,7 @@ func (c *Conn) LocalAddr() net.Addr { } // RemoteAddr implements net.Conn.RemoteAddr. -func (c *Conn) RemoteAddr() net.Addr { +func (c *TCPConn) RemoteAddr() net.Addr { a, err := c.ep.GetRemoteAddress() if err != nil { return nil @@ -476,7 +486,7 @@ func (c *Conn) RemoteAddr() net.Addr { return fullToTCPAddr(a) } -func (c *Conn) newOpError(op string, err error) *net.OpError { +func (c *TCPConn) newOpError(op string, err error) *net.OpError { return &net.OpError{ Op: op, Net: "tcp", @@ -494,14 +504,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} } -// DialTCP creates a new TCP Conn connected to the specified address. -func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { +// DialTCP creates a new TCPConn connected to the specified address. +func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { return DialContextTCP(context.Background(), s, addr, network) } -// DialContextTCP creates a new TCP Conn connected to the specified address +// DialContextTCP creates a new TCPConn connected to the specified address // with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -543,12 +553,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, } } - return NewConn(&wq, ep), nil + return NewTCPConn(&wq, ep), nil } -// A PacketConn is a wrapper around a tcpip endpoint that implements -// net.PacketConn. -type PacketConn struct { +// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements +// net.Conn and net.PacketConn. +type UDPConn struct { deadlineTimer stack *stack.Stack @@ -556,12 +566,23 @@ type PacketConn struct { wq *waiter.Queue } -// DialUDP creates a new PacketConn. +// NewUDPConn creates a new UDPConn. +func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn { + c := &UDPConn{ + stack: s, + ep: ep, + wq: wq, + } + c.deadlineTimer.init() + return c +} + +// DialUDP creates a new UDPConn. // // If laddr is nil, a local address is automatically chosen. // -// If raddr is nil, the PacketConn is left unconnected. -func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { +// If raddr is nil, the UDPConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) { var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) if err != nil { @@ -580,12 +601,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw } } - c := PacketConn{ - stack: s, - ep: ep, - wq: &wq, - } - c.deadlineTimer.init() + c := NewUDPConn(s, &wq, ep) if raddr != nil { if err := c.ep.Connect(*raddr); err != nil { @@ -599,14 +615,14 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw } } - return &c, nil + return c, nil } -func (c *PacketConn) newOpError(op string, err error) *net.OpError { +func (c *UDPConn) newOpError(op string, err error) *net.OpError { return c.newRemoteOpError(op, nil, err) } -func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { +func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { return &net.OpError{ Op: op, Net: "udp", @@ -617,22 +633,22 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne } // RemoteAddr implements net.Conn.RemoteAddr. -func (c *PacketConn) RemoteAddr() net.Addr { +func (c *UDPConn) RemoteAddr() net.Addr { a, err := c.ep.GetRemoteAddress() if err != nil { return nil } - return fullToTCPAddr(a) + return fullToUDPAddr(a) } // Read implements net.Conn.Read -func (c *PacketConn) Read(b []byte) (int, error) { +func (c *UDPConn) Read(b []byte) (int, error) { bytesRead, _, err := c.ReadFrom(b) return bytesRead, err } // ReadFrom implements net.PacketConn.ReadFrom. -func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { deadline := c.readCancel() var addr tcpip.FullAddress @@ -644,12 +660,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { return copy(b, read), fullToUDPAddr(addr), nil } -func (c *PacketConn) Write(b []byte) (int, error) { +func (c *UDPConn) Write(b []byte) (int, error) { return c.WriteTo(b, nil) } // WriteTo implements net.PacketConn.WriteTo. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { deadline := c.writeCancel() // Check if deadline has already expired. @@ -707,13 +723,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // Close implements net.PacketConn.Close. -func (c *PacketConn) Close() error { +func (c *UDPConn) Close() error { c.ep.Close() return nil } // LocalAddr implements net.PacketConn.LocalAddr. -func (c *PacketConn) LocalAddr() net.Addr { +func (c *UDPConn) LocalAddr() net.Addr { a, err := c.ep.GetLocalAddress() if err != nil { return nil diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -41,7 +41,7 @@ const ( ) func TestTimeouts(t *testing.T) { - nc := NewConn(nil, nil) + nc := NewTCPConn(nil, nil) dlfs := []struct { name string f func(time.Time) error @@ -127,12 +127,16 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - l, e := NewListener(s, addr, ipv4.ProtocolNumber) + l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { t.Fatalf("NewListener() = %v", e) } @@ -151,10 +155,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -170,13 +172,17 @@ func TestCloseReader(t *testing.T) { sender.close() } -// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when +// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when // using tcp.Forwarder. func TestCloseReaderWithForwarder(t *testing.T) { s, err := newLoopbackStack() if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -194,7 +200,7 @@ func TestCloseReaderWithForwarder(t *testing.T) { defer ep.Close() r.Complete(false) - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) // Give c.Read() a chance to block before closing the connection. time.AfterFunc(time.Millisecond*50, func() { @@ -203,10 +209,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -229,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -261,7 +256,7 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("connect() = %v", terr) } - c := NewConn(tc.wq, tc.ep) + c := NewTCPConn(tc.wq, tc.ep) if err := c.CloseRead(); err != nil { t.Errorf("c.CloseRead() = %v", err) @@ -282,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -295,7 +294,7 @@ func TestCloseWrite(t *testing.T) { defer ep.Close() r.Complete(false) - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) n, e := c.Read(make([]byte, 256)) if n != 0 || e != io.EOF { @@ -313,7 +312,7 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("connect() = %v", terr) } - c := NewConn(tc.wq, tc.ep) + c := NewTCPConn(tc.wq, tc.ep) if err := c.CloseWrite(); err != nil { t.Errorf("c.CloseWrite() = %v", err) @@ -338,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -357,7 +360,7 @@ func TestUDPForwarder(t *testing.T) { } defer ep.Close() - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) buf := make([]byte, 256) n, e := c.Read(buf) @@ -395,12 +398,16 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - l, e := NewListener(s, addr, ipv4.ProtocolNumber) + l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { t.Fatalf("NewListener() = %v", e) } @@ -444,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -496,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -545,7 +560,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { addr := tcpip.FullAddress{NICID, ip, 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - l, err := NewListener(s, addr, ipv4.ProtocolNumber) + l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) if err != nil { return nil, nil, nil, fmt.Errorf("NewListener: %v", err) } @@ -566,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -628,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -645,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -663,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index a7bf0c4dc..5e135c50d 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "prependable.go", "view.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer", visibility = ["//visibility:public"], ) @@ -20,5 +18,5 @@ go_test( "prependable_test.go", "view_test.go", ], - embed = [":buffer"], + library = ":buffer", ) diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 2f9a23d61..57d1922ab 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -42,7 +42,7 @@ func NewPrependableFromView(v View, extraCap int) Prependable { if extraCap == 0 { return Prependable{buf: v, usedIdx: 0} } - buf := make([]byte, extraCap, extraCap + len(v)) + buf := make([]byte, extraCap, extraCap+len(v)) buf = append(buf, v...) return Prependable{buf: buf, usedIdx: extraCap} } @@ -83,3 +83,9 @@ func (p *Prependable) Prepend(size int) []byte { p.usedIdx -= size return p.View()[:size:size] } + +// DeepCopy copies p and the bytes backing it. +func (p Prependable) DeepCopy() Prependable { + p.buf = append(View(nil), p.buf...) + return p +} diff --git a/pkg/tcpip/buffer/prependable_test.go b/pkg/tcpip/buffer/prependable_test.go index 43660c307..435a94a61 100644 --- a/pkg/tcpip/buffer/prependable_test.go +++ b/pkg/tcpip/buffer/prependable_test.go @@ -45,6 +45,6 @@ func TestNewPrependableFromView(t *testing.T) { if !reflect.DeepEqual(prep, testCase.want) { t.Errorf("NewPrependableFromView(%#v, %d) = %#v; want %#v", testCase.view, testCase.extraSize, prep, testCase.want) } - } ) + }) } } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 150310c11..ea0c5413d 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -15,6 +15,11 @@ // Package buffer provides the implementation of a buffer view. package buffer +import ( + "bytes" + "io" +) + // View is a slice of a buffer, with convenience methods. type View []byte @@ -45,11 +50,31 @@ func (v *View) CapLength(length int) { *v = (*v)[:length:length] } +// Reader returns a bytes.Reader for v. +func (v *View) Reader() bytes.Reader { + var r bytes.Reader + r.Reset(*v) + return r +} + // ToVectorisedView returns a VectorisedView containing the receiver. func (v View) ToVectorisedView() VectorisedView { + if len(v) == 0 { + return VectorisedView{} + } return NewVectorisedView(len(v), []View{v}) } +// IsEmpty returns whether v is of length zero. +func (v View) IsEmpty() bool { + return len(v) == 0 +} + +// Size returns the length of v. +func (v View) Size() int { + return len(v) +} + // VectorisedView is a vectorised version of View using non contiguous memory. // It supports all the convenience methods supported by View. // @@ -65,7 +90,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -74,8 +100,49 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() + } +} + +// Read implements io.Reader. +func (vv *VectorisedView) Read(v View) (copied int, err error) { + count := len(v) + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + copy(v[copied:], vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return copied, nil + } + count -= len(vv.views[0]) + copy(v[copied:], vv.views[0]) + copied += len(vv.views[0]) + vv.removeFirst() + } + if copied == 0 { + return 0, io.EOF + } + return copied, nil +} + +// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It +// returns the number of bytes copied. +func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) { + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + dstVV.AppendView(vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return + } + count -= len(vv.views[0]) + dstVV.AppendView(vv.views[0]) + copied += len(vv.views[0]) + vv.removeFirst() } + return copied } // CapLength irreversibly reduces the length of the vectorised view. @@ -105,29 +172,45 @@ func (vv *VectorisedView) CapLength(length int) { // Clone returns a clone of this VectorisedView. // If the buffer argument is large enough to contain all the Views of this VectorisedView, // the method will avoid allocations and use the buffer to store the Views of the clone. -func (vv VectorisedView) Clone(buffer []View) VectorisedView { +func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. -func (vv VectorisedView) Size() int { +func (vv *VectorisedView) Size() int { return vv.size } @@ -135,7 +218,7 @@ func (vv VectorisedView) Size() int { // // If the vectorised view contains a single view, that view will be returned // directly. -func (vv VectorisedView) ToView() View { +func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } @@ -147,7 +230,7 @@ func (vv VectorisedView) ToView() View { } // Views returns the slice containing the all views. -func (vv VectorisedView) Views() []View { +func (vv *VectorisedView) Views() []View { return vv.views } @@ -156,3 +239,28 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { vv.views = append(vv.views, vv2.views...) vv.size += vv2.size } + +// AppendView appends the given view into this vectorised view. +func (vv *VectorisedView) AppendView(v View) { + if len(v) == 0 { + return + } + vv.views = append(vv.views, v) + vv.size += len(v) +} + +// Readers returns a bytes.Reader for each of vv's views. +func (vv *VectorisedView) Readers() []bytes.Reader { + readers := make([]bytes.Reader, 0, len(vv.views)) + for _, v := range vv.views { + readers = append(readers, v.Reader()) + } + return readers +} + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index ebc3a17b7..726e54de9 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -233,3 +234,288 @@ func TestToClone(t *testing.T) { }) } } + +func TestVVReadToVV(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + wantBytes string + leftVV VectorisedView + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + wantBytes: "0123456789", + leftVV: vv(20, "01234567890123456789"), + }, + { + comment: "largeVV, multiple views, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + wantBytes: "123345", + leftVV: vv(7, "567", "8910"), + }, + { + comment: "smallVV (multiple views), large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + wantBytes: "123", + leftVV: vv(0, ""), + }, + { + comment: "smallVV (single view), large read", + vv: vv(1, "1"), + bytesToRead: 10, + wantBytes: "1", + leftVV: vv(0, ""), + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + wantBytes: "", + leftVV: vv(0, ""), + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + var readTo VectorisedView + inSize := tc.vv.Size() + copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) + if got, want := copied, len(tc.wantBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc) + } + if got, want := string(readTo.ToView()), tc.wantBytes; got != want { + t.Errorf("unexpected content in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { + t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want) + } + }) + } +} + +func TestVVRead(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + readBytes string + leftBytes string + wantError bool + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + readBytes: "0123456789", + leftBytes: "01234567890123456789", + }, + { + comment: "largeVV, multiple buffers, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + readBytes: "123345", + leftBytes: "5678910", + }, + { + comment: "smallVV, large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + readBytes: "123", + leftBytes: "", + }, + { + comment: "smallVV, large read", + vv: vv(1, "1"), + bytesToRead: 10, + readBytes: "1", + leftBytes: "", + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + readBytes: "", + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + readTo := NewView(tc.bytesToRead) + inSize := tc.vv.Size() + copied, err := tc.vv.Read(readTo) + if !tc.wantError && err != nil { + t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err) + } + readTo = readTo[:copied] + if got, want := copied, len(tc.readBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(readTo), tc.readBytes; got != want { + t.Errorf("unexpected data in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want { + t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want) + } + }) + } +} + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} + +func TestToVectorisedView(t *testing.T) { + testCases := []struct { + in View + want VectorisedView + }{ + {nil, VectorisedView{}}, + {View{}, VectorisedView{}}, + {View{'a'}, VectorisedView{size: 1, views: []View{{'a'}}}}, + } + for _, tc := range testCases { + if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} + +func TestAppendView(t *testing.T) { + testCases := []struct { + vv VectorisedView + in View + want VectorisedView + }{ + {VectorisedView{}, nil, VectorisedView{}}, + {VectorisedView{}, View{}, VectorisedView{}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, nil, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{'e'}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}, {'e'}}, 5}}, + } + for _, tc := range testCases { + tc.vv.AppendView(tc.in) + if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index b6fa6fc37..c984470e6 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,12 +6,12 @@ go_library( name = "checker", testonly = 1, srcs = ["checker.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/checker", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f15bf1f1..b769094dc 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -33,6 +34,9 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) +// ControlMessagesChecker is a function to check a property of ancillary data. +type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) + // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -104,6 +108,8 @@ func DstAddr(addr tcpip.Address) NetworkChecker { // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). func TTL(ttl uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + var v uint8 switch ip := h[0].(type) { case header.IPv4: @@ -158,6 +164,44 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTClass creates a checker that checks the TCLASS field in +// ControlMessages. +func ReceiveTClass(want uint32) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTClass { + t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) + } else if got := cm.TClass; got != want { + t.Errorf("got cm.TClass = %d, want %d", got, want) + } + } +} + +// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. +func ReceiveTOS(want uint8) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTOS { + t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) + } else if got := cm.TOS; got != want { + t.Errorf("got cm.TOS = %d, want %d", got, want) + } + } +} + +// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in +// ControlMessages. +func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPPacketInfo { + t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) + } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { + t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -280,12 +324,30 @@ func SrcPort(port uint16) TransportChecker { // DstPort creates a checker that checks the destination port. func DstPort(port uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if p := h.DestinationPort(); p != port { t.Errorf("Bad destination port, got %v, want %v", p, port) } } } +// NoChecksum creates a checker that checks if the checksum is zero. +func NoChecksum(noChecksum bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + udp, ok := h.(header.UDP) + if !ok { + return + } + + if b := udp.Checksum() == 0; b != noChecksum { + t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) + } + } +} + // SeqNum creates a checker that checks the sequence number. func SeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { @@ -306,6 +368,7 @@ func SeqNum(seq uint32) TransportChecker { func AckNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -320,6 +383,8 @@ func AckNum(seq uint32) TransportChecker { // Window creates a checker that checks the tcp window. func Window(window uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -351,6 +416,8 @@ func TCPFlags(flags uint8) TransportChecker { // given mask, match the supplied flags. func TCPFlagsMatch(flags, mask uint8) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -368,6 +435,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { // If wndscale is negative, the window scale option must not be present. func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -464,6 +533,8 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { // skipped. func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -582,6 +653,8 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { // Payload creates a checker that checks the payload. func Payload(want []byte) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if got := h.Payload(); !reflect.DeepEqual(got, want) { t.Errorf("Wrong payload, got %v, want %v", got, want) } @@ -614,6 +687,7 @@ func ICMPv4(checkers ...TransportChecker) NetworkChecker { func ICMPv4Type(want header.ICMPv4Type) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv4, ok := h.(header.ICMPv4) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) @@ -625,9 +699,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { } // ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want byte) TransportChecker { +func ICMPv4Code(want header.ICMPv4Code) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv4, ok := h.(header.ICMPv4) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) @@ -670,6 +745,7 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { func ICMPv6Type(want header.ICMPv6Type) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv6, ok := h.(header.ICMPv6) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) @@ -681,9 +757,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { } // ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want byte) TransportChecker { +func ICMPv6Code(want header.ICMPv6Code) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv6, ok := h.(header.ICMPv6) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) @@ -698,7 +775,7 @@ func ICMPv6Code(want byte) TransportChecker { // message for type of ty, with potentially additional checks specified by // checkers. // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid // NDP message as far as the size of the message (minSize) is concerned. The // values within the message are up to checkers to validate. func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { @@ -730,9 +807,9 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N // Neighbor Solicitation message (as per the raw wire format), with potentially // additional checks specified by checkers. // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the messages concerned. The values within -// the message are up to checkers to validate. +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNS message as far as the size of the message is concerned. The values +// within the message are up to checkers to validate. func NDPNS(checkers ...TransportChecker) NetworkChecker { return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) } @@ -750,7 +827,162 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { ns := header.NDPNeighborSolicit(icmp.NDPPayload()) if got := ns.TargetAddress(); got != want { - t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) + t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) + } + } +} + +// NDPNA creates a checker that checks that the packet contains a valid NDP +// Neighbor Advertisement message (as per the raw wire format), with potentially +// additional checks specified by checkers. +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNA message as far as the size of the message is concerned. The values +// within the message are up to checkers to validate. +func NDPNA(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) +} + +// NDPNATargetAddress creates a checker that checks the Target Address field of +// a header.NDPNeighborAdvert. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNATargetAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + + if got := na.TargetAddress(); got != want { + t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) } } } + +// NDPNASolicitedFlag creates a checker that checks the Solicited field of +// a header.NDPNeighborAdvert. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNASolicitedFlag(want bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + + if got := na.SolicitedFlag(); got != want { + t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) + } + } +} + +// ndpOptions checks that optsBuf only contains opts. +func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { + t.Helper() + + it, err := optsBuf.Iter(true) + if err != nil { + t.Errorf("optsBuf.Iter(true): %s", err) + return + } + + i := 0 + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + t.Fatalf("unexpected error when iterating over NDP options: %s", err) + } + if done { + break + } + + if i >= len(opts) { + t.Errorf("got unexpected option: %s", opt) + continue + } + + switch wantOpt := opts[i].(type) { + case header.NDPSourceLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + case header.NDPTargetLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + default: + t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) + } + + i++ + } + + if missing := opts[i:]; len(missing) > 0 { + t.Errorf("missing options: %s", missing) + } +} + +// NDPNAOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Neighbor Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNAOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + ndpOptions(t, na.Options(), opts) + } +} + +// NDPNSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Neighbor Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNS message as far as the size is concerned. +func NDPNSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ndpOptions(t, ns.Options(), opts) + } +} + +// NDPRS creates a checker that checks that the packet contains a valid NDP +// Router Solicitation message (as per the raw wire format). +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPRS as far as the size of the message is concerned. The values within the +// message are up to checkers to validate. +func NDPRS(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) +} + +// NDPRSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Router Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPRS message as far as the size is concerned. +func NDPRSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + rs := header.NDPRouterSolicit(icmp.NDPPayload()) + ndpOptions(t, rs.Options(), opts) + } +} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 0c5c20cea..ff2719291 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,15 +1,11 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "jenkins", srcs = ["jenkins.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) go_test( @@ -18,5 +14,5 @@ go_test( srcs = [ "jenkins_test.go", ], - embed = [":jenkins"], + library = ":jenkins", ) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index a3485b35c..d87797617 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -15,15 +14,17 @@ go_library( "interfaces.go", "ipv4.go", "ipv6.go", + "ipv6_extension_headers.go", "ipv6_fragment.go", "ndp_neighbor_advert.go", "ndp_neighbor_solicit.go", "ndp_options.go", "ndp_router_advert.go", + "ndp_router_solicit.go", + "ndpoptionidentifier_string.go", "tcp.go", "udp.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -38,12 +39,16 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "ipv6_test.go", "ipversion_test.go", "tcp_test.go", ], deps = [ ":header", + "//pkg/rand", + "//pkg/tcpip", "//pkg/tcpip/buffer", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -52,8 +57,13 @@ go_test( size = "small", srcs = [ "eth_test.go", + "ipv6_extension_headers_test.go", "ndp_test.go", ], - embed = [":header"], - deps = ["//pkg/tcpip"], + library = ":header", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "@com_github_google_go_cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go index 718a4720a..83189676e 100644 --- a/pkg/tcpip/header/arp.go +++ b/pkg/tcpip/header/arp.go @@ -14,14 +14,33 @@ package header -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // ARPProtocolNumber is the ARP network protocol number. ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 // ARPSize is the size of an IPv4-over-Ethernet ARP packet. - ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 + ARPSize = 28 +) + +// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header. +type ARPHardwareType uint16 + +// Typical ARP HardwareType values. Some of the constants have to be specific +// values as they are egressed on the wire in the HTYPE field of an ARP header. +const ( + ARPHardwareNone ARPHardwareType = 0 + // ARPHardwareEther specifically is the HTYPE for Ethernet as specified + // in the IANA list here: + // + // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2 + ARPHardwareEther ARPHardwareType = 1 + ARPHardwareLoopback ARPHardwareType = 2 ) // ARPOp is an ARP opcode. @@ -36,54 +55,64 @@ const ( // ARP is an ARP packet stored in a byte array as described in RFC 826. type ARP []byte -func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } -func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } -func (a ARP) hardwareAddressSize() int { return int(a[4]) } -func (a ARP) protocolAddressSize() int { return int(a[5]) } +const ( + hTypeOffset = 0 + protocolOffset = 2 + haAddressSizeOffset = 4 + protoAddressSizeOffset = 5 + opCodeOffset = 6 + senderHAAddressOffset = 8 + senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize + targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize + targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize +) + +func (a ARP) hardwareAddressType() ARPHardwareType { + return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:])) +} + +func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) } +func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) } +func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) } // Op is the ARP opcode. -func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } +func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) } // SetOp sets the ARP opcode. func (a ARP) SetOp(op ARPOp) { - a[6] = uint8(op >> 8) - a[7] = uint8(op) + binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op)) } // SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet. func (a ARP) SetIPv4OverEthernet() { - a[0], a[1] = 0, 1 // htypeEthernet - a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber - a[4] = 6 // macSize - a[5] = uint8(IPv4AddressSize) + binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther)) + binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber)) + a[haAddressSizeOffset] = EthernetAddressSize + a[protoAddressSizeOffset] = uint8(IPv4AddressSize) } // HardwareAddressSender is the link address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressSender() []byte { - const s = 8 - return a[s : s+6] + return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize] } // ProtocolAddressSender is the protocol address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressSender() []byte { - const s = 8 + 6 - return a[s : s+4] + return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize] } // HardwareAddressTarget is the link address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressTarget() []byte { - const s = 8 + 6 + 4 - return a[s : s+6] + return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize] } // ProtocolAddressTarget is the protocol address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressTarget() []byte { - const s = 8 + 6 + 4 + 6 - return a[s : s+4] + return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize] } // IsValid reports whether this is an ARP packet for IPv4 over Ethernet. @@ -91,10 +120,8 @@ func (a ARP) IsValid() bool { if len(a) < ARPSize { return false } - const htypeEthernet = 1 - const macSize = 6 - return a.hardwareAddressSpace() == htypeEthernet && + return a.hardwareAddressType() == ARPHardwareEther && a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && - a.hardwareAddressSize() == macSize && + a.hardwareAddressSize() == EthernetAddressSize && a.protocolAddressSize() == IPv4AddressSize } diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 9749c7f4d..14a4b2b44 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -45,12 +45,139 @@ func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { return ChecksumCombine(uint16(v), uint16(v>>16)), odd } +func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { + v := initial + + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + + l := len(buf) + odd = l&1 != 0 + if odd { + l-- + v += uint32(buf[l]) << 8 + } + for (l - 64) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[64:] + l = l - 64 + } + if (l - 32) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[32:] + l = l - 32 + } + if (l - 16) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[16:] + l = l - 16 + } + if (l - 8) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + buf = buf[8:] + l = l - 8 + } + if (l - 4) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + buf = buf[4:] + l = l - 4 + } + + // At this point since l was even before we started unrolling + // there can be only two bytes left to add. + if l != 0 { + v += (uint32(buf[0]) << 8) + uint32(buf[1]) + } + + return ChecksumCombine(uint16(v), uint16(v>>16)), odd +} + +// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in +// the given byte array. This function uses a non-optimized implementation. Its +// only retained for reference and to use as a benchmark/test. Most code should +// use the header.Checksum function. +// +// The initial checksum must have been computed on an even number of bytes. +func ChecksumOld(buf []byte, initial uint16) uint16 { + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s +} + // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. +// given byte array. This function uses an optimized unrolled version of the +// checksum algorithm. // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - s, _ := calculateChecksum(buf, false, uint32(initial)) + s, _ := unrolledCalculateChecksum(buf, false, uint32(initial)) return s } @@ -86,7 +213,7 @@ func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, siz } v = v[:l] - sum, odd = calculateChecksum(v, odd, uint32(sum)) + sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum)) size -= len(v) if size == 0 { diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 86b466c1c..309403482 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -17,6 +17,8 @@ package header_test import ( + "fmt" + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) { }) } } + +func TestChecksum(t *testing.T) { + var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} + type testCase struct { + buf []byte + initial uint16 + csumOrig uint16 + csumNew uint16 + } + testCases := make([]testCase, 100000) + // Ensure same buffer generation for test consistency. + rnd := rand.New(rand.NewSource(42)) + for i := range testCases { + testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) + testCases[i].initial = uint16(rnd.Intn(65536)) + rnd.Read(testCases[i].buf) + } + + for i := range testCases { + testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) + testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) + if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { + t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) + } + } +} + +func BenchmarkChecksum(b *testing.B) { + var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} + + checkSumImpls := []struct { + fn func([]byte, uint16) uint16 + name string + }{ + {header.ChecksumOld, fmt.Sprintf("checksum_old")}, + {header.Checksum, fmt.Sprintf("checksum")}, + } + + for _, csumImpl := range checkSumImpls { + // Ensure same buffer generation for test consistency. + rnd := rand.New(rand.NewSource(42)) + for _, bufSz := range bufSizes { + b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { + tc := struct { + buf []byte + initial uint16 + csum uint16 + }{ + buf: make([]byte, bufSz), + initial: uint16(rnd.Intn(65536)), + } + rnd.Read(tc.buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tc.csum = csumImpl.fn(tc.buf, tc.initial) + } + }) + } + } +} diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index f5d2c127f..eaface8cb 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -53,6 +53,10 @@ const ( // (all bits set to 0). unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + // EthernetBroadcastAddress is an ethernet address that addresses every node + // on a local link. + EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff") + // unicastMulticastFlagMask is the mask of the least significant bit in // the first octet (in network byte order) of an ethernet address that // determines whether the ethernet address is a unicast or multicast. If @@ -134,3 +138,44 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { // addr is a valid unicast ethernet address. return true } + +// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address +// for a multicast IPv4 address. +// +// addr MUST be a multicast IPv4 address. +func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress { + var linkAddrBytes [EthernetAddressSize]byte + // RFC 1112 Host Extensions for IP Multicasting + // + // 6.4. Extensions to an Ethernet Local Network Module: + // + // An IP host group address is mapped to an Ethernet multicast + // address by placing the low-order 23-bits of the IP address + // into the low-order 23 bits of the Ethernet multicast address + // 01-00-5E-00-00-00 (hex). + linkAddrBytes[0] = 0x1 + linkAddrBytes[2] = 0x5e + linkAddrBytes[3] = addr[1] & 0x7F + copy(linkAddrBytes[4:], addr[IPv4AddressSize-2:]) + return tcpip.LinkAddress(linkAddrBytes[:]) +} + +// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address +// for a multicast IPv6 address. +// +// addr MUST be a multicast IPv6 address. +func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress { + // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks + // + // 7. Address Mapping -- Multicast + // + // An IPv6 packet with a multicast destination address DST, + // consisting of the sixteen octets DST[1] through DST[16], is + // transmitted to the Ethernet multicast address whose first + // two octets are the value 3333 hexadecimal and whose last + // four octets are the last four octets of DST. + linkAddrBytes := []byte(addr[IPv6AddressSize-EthernetAddressSize:]) + linkAddrBytes[0] = 0x33 + linkAddrBytes[1] = 0x33 + return tcpip.LinkAddress(linkAddrBytes[:]) +} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 6634c90f5..14413f2ce 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -66,3 +66,37 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) { }) } } + +func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4 Multicast without 24th bit set", + addr: "\xe0\x7e\xdc\xba", + expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", + }, + { + name: "IPv4 Multicast with 24th bit set", + addr: "\xe0\xfe\xdc\xba", + expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { + t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr) + } + }) + } +} + +func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { + addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a") + if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { + t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) + } +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 0cac6c0a5..be03fb086 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -54,6 +54,9 @@ const ( // ICMPv4Type is the ICMP type field described in RFC 792. type ICMPv4Type byte +// ICMPv4Code is the ICMP code field described in RFC 792. +type ICMPv4Code byte + // Typical values of ICMPv4Type defined in RFC 792. const ( ICMPv4EchoReply ICMPv4Type = 0 @@ -69,10 +72,13 @@ const ( ICMPv4InfoReply ICMPv4Type = 16 ) -// Values for ICMP code as defined in RFC 792. +// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. const ( - ICMPv4PortUnreachable = 3 - ICMPv4FragmentationNeeded = 4 + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4HostUnreachable ICMPv4Code = 1 + ICMPv4ProtoUnreachable ICMPv4Code = 2 + ICMPv4PortUnreachable ICMPv4Code = 3 + ICMPv4FragmentationNeeded ICMPv4Code = 4 ) // Type is the ICMP type field. @@ -82,10 +88,10 @@ func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) } // Code is the ICMP code field. Its meaning depends on the value of Type. -func (b ICMPv4) Code() byte { return b[1] } +func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. -func (b ICMPv4) SetCode(c byte) { b[1] = c } +func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index b4037b6c8..20b01d8f4 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -52,7 +52,7 @@ const ( // ICMPv6NeighborAdvertSize is size of a neighbor advertisement // including the NDP Target Link Layer option for an Ethernet // address. - ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize + ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. ICMPv6EchoMinimumSize = 8 @@ -92,7 +92,6 @@ const ( // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. type ICMPv6Type byte -// Typical values of ICMPv6Type defined in RFC 4443. const ( ICMPv6DstUnreachable ICMPv6Type = 1 ICMPv6PacketTooBig ICMPv6Type = 2 @@ -110,11 +109,38 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// Values for ICMP code as defined in RFC 4443. +// ICMPv6Code is the ICMP code field described in RFC 4443. +type ICMPv6Code byte + +// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443 +// section 3.1. +const ( + ICMPv6NetworkUnreachable ICMPv6Code = 0 + ICMPv6Prohibited ICMPv6Code = 1 + ICMPv6BeyondScope ICMPv6Code = 2 + ICMPv6AddressUnreachable ICMPv6Code = 3 + ICMPv6PortUnreachable ICMPv6Code = 4 + ICMPv6Policy ICMPv6Code = 5 + ICMPv6RejectRoute ICMPv6Code = 6 +) + +// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3. const ( - ICMPv6PortUnreachable = 4 + ICMPv6HopLimitExceeded ICMPv6Code = 0 + ICMPv6ReassemblyTimeout ICMPv6Code = 1 ) +// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. +const ( + ICMPv6ErroneousHeader ICMPv6Code = 0 + ICMPv6UnknownHeader ICMPv6Code = 1 + ICMPv6UnknownOption ICMPv6Code = 2 +) + +// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use +// the code field. (Types not mentioned above.) +const ICMPv6UnusedCode ICMPv6Code = 0 + // Type is the ICMP type field. func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } @@ -122,10 +148,10 @@ func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) } // Code is the ICMP code field. Its meaning depends on the value of Type. -func (b ICMPv6) Code() byte { return b[1] } +func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. -func (b ICMPv6) SetCode(c byte) { b[1] = c } +func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e5360e7c1..680eafd16 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -38,7 +38,8 @@ const ( // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the // fields of a packet that needs to be encoded. type IPv4Fields struct { - // IHL is the "internet header length" field of an IPv4 packet. + // IHL is the "internet header length" field of an IPv4 packet. The value + // is in bytes. IHL uint8 // TOS is the "type of service" field of an IPv4 packet. @@ -100,6 +101,11 @@ const ( // IPv4Version is the version of the ipv4 protocol. IPv4Version = 4 + // IPv4AllSystems is the all systems IPv4 multicast address as per + // IANA's IPv4 Multicast Address Space Registry. See + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml. + IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01" + // IPv4Broadcast is the broadcast address of the IPv4 procotol. IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff" @@ -138,7 +144,7 @@ func IPVersion(b []byte) int { } // HeaderLength returns the value of the "header length" field of the ipv4 -// header. +// header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { return (b[versIHL] & 0xf) * 4 } @@ -158,6 +164,11 @@ func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } +// More returns whether the more fragments flag is set. +func (b IPv4) More() bool { + return b.Flags()&IPv4FlagMoreFragments != 0 +} + // TTL returns the "TTL" field of the ipv4 header. func (b IPv4) TTL() uint8 { return b[ttl] @@ -304,3 +315,12 @@ func IsV4MulticastAddress(addr tcpip.Address) bool { } return (addr[0] & 0xf0) == 0xe0 } + +// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback +// address (belongs to 127.0.0.1/8 subnet). +func IsV4LoopbackAddress(addr tcpip.Address) bool { + if len(addr) != IPv4AddressSize { + return false + } + return addr[0] == 0x7f +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index f1e60911b..ea3823898 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -15,7 +15,9 @@ package header import ( + "crypto/sha256" "encoding/binary" + "fmt" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -26,7 +28,9 @@ const ( // IPv6PayloadLenOffset is the offset of the PayloadLength field in // IPv6 header. IPv6PayloadLenOffset = 4 - nextHdr = 6 + // IPv6NextHeaderOffset is the offset of the NextHeader field in + // IPv6 header. + IPv6NextHeaderOffset = 6 hopLimit = 7 v6SrcAddr = 8 v6DstAddr = v6SrcAddr + IPv6AddressSize @@ -83,16 +87,58 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6AllRoutersMulticastAddress is a link-local multicast group that + // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers on a link. + // + // The address is ff02::2. + IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 + // IPv6Loopback is the IPv6 Loopback address. + IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 + + // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes + // for the secret key used to generate an opaque interface identifier as + // outlined by RFC 7217. + OpaqueIIDSecretKeyMinBytes = 16 + + // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field + // is located within a multicast IPv6 address, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeByteIdx = 1 + + // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, + // within the byte holding the field, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeMask = 0xF + + // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within + // a multicast IPv6 address that indicates the address has link-local scope, + // as per RFC 4291 section 2.7. + ipv6LinkLocalMulticastScope = 2 ) -// IPv6EmptySubnet is the empty IPv6 subnet. +// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the +// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to +// be contained within this subnet. var IPv6EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) if err != nil { @@ -123,7 +169,7 @@ func (b IPv6) HopLimit() uint8 { // NextHeader returns the value of the "next header" field of the ipv6 header. func (b IPv6) NextHeader() uint8 { - return b[nextHdr] + return b[IPv6NextHeaderOffset] } // TransportProtocol implements Network.TransportProtocol. @@ -183,7 +229,7 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) { // SetNextHeader sets the value of the "next header" field of the ipv6 header. func (b IPv6) SetNextHeader(v uint8) { - b[nextHdr] = v + b[IPv6NextHeaderOffset] = v } // SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a @@ -195,7 +241,7 @@ func (IPv6) SetChecksum(uint16) { func (b IPv6) Encode(i *IPv6Fields) { b.SetTOS(i.TrafficClass, i.FlowLabel) b.SetPayloadLength(i.PayloadLength) - b[nextHdr] = i.NextHeader + b[IPv6NextHeaderOffset] = i.NextHeader b[hopLimit] = i.HopLimit b.SetSourceAddress(i.SrcAddr) b.SetDestinationAddress(i.DstAddr) @@ -264,27 +310,43 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 +// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFF + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit +// Ethernet/MAC address, as per RFC 4291 section 2.5.1. +func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { - // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local - // header, FE80::. + // Convert a 48-bit MAC to a modified EUI-64 and then prepend the + // link-local header, FE80::. // // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff // Note the capital A. The conversion aa->Aa involves a bit flip. - lladdrb := [16]byte{ - 0: 0xFE, - 1: 0x80, - 8: linkAddr[0] ^ 2, - 9: linkAddr[1], - 10: linkAddr[2], - 11: 0xFF, - 12: 0xFE, - 13: linkAddr[3], - 14: linkAddr[4], - 15: linkAddr[5], + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, } + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } @@ -296,3 +358,145 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } + +// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 +// link-local multicast address. +func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { + return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope +} + +// IsV6UniqueLocalAddress determines if the provided address is an IPv6 +// unique-local address (within the prefix FC00::/7). +func IsV6UniqueLocalAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + // According to RFC 4193 section 3.1, a unique local address has the prefix + // FC00::/7. + return (addr[0] & 0xfe) == 0xfc +} + +// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier +// (IID) to buf as outlined by RFC 7217 and returns the extended buffer. +// +// The opaque IID is generated from the cryptographic hash of the concatenation +// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret +// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and +// MUST be generated to a pseudo-random number. See RFC 4086 for randomness +// requirements for security. +// +// If buf has enough capacity for the IID (IIDSize bytes), a new underlying +// array for the buffer will not be allocated. +func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte { + // As per RFC 7217 section 5, the opaque identifier can be generated as a + // cryptographic hash of the concatenation of each of the function parameters. + // Note, we omit the optional Network_ID field. + h := sha256.New() + // h.Write never returns an error. + h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address])) + h.Write([]byte(nicName)) + h.Write([]byte{dadCounter}) + h.Write(secretKey) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + return append(buf, sum[:IIDSize]...) +} + +// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with +// an opaque IID. +func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address { + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)) +} + +// IPv6AddressScope is the scope of an IPv6 address. +type IPv6AddressScope int + +const ( + // LinkLocalScope indicates a link-local address. + LinkLocalScope IPv6AddressScope = iota + + // UniqueLocalScope indicates a unique-local address. + UniqueLocalScope + + // GlobalScope indicates a global address. + GlobalScope +) + +// ScopeForIPv6Address returns the scope for an IPv6 address. +func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { + if len(addr) != IPv6AddressSize { + return GlobalScope, tcpip.ErrBadAddress + } + + switch { + case IsV6LinkLocalMulticastAddress(addr): + return LinkLocalScope, nil + + case IsV6LinkLocalAddress(addr): + return LinkLocalScope, nil + + case IsV6UniqueLocalAddress(addr): + return UniqueLocalScope, nil + + default: + return GlobalScope, nil + } +} + +// InitialTempIID generates the initial temporary IID history value to generate +// temporary SLAAC addresses with. +// +// Panics if initialTempIIDHistory is not at least IIDSize bytes. +func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) { + h := sha256.New() + // h.Write never returns an error. + h.Write(seed) + var nicIDBuf [4]byte + binary.BigEndian.PutUint32(nicIDBuf[:], uint32(nicID)) + h.Write(nicIDBuf[:]) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + if n := copy(initialTempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } +} + +// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an +// associated stable/permanent SLAAC address. +// +// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be +// used when generating a new temporary IID. +// +// Panics if tempIIDHistory is not at least IIDSize bytes. +func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix { + addrBytes := []byte(stableAddr) + h := sha256.New() + h.Write(tempIIDHistory) + h.Write(addrBytes[IIDOffsetInIPv6Address:]) + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + // The rightmost 64 bits of sum are saved for the next iteration. + if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } + + // The leftmost 64 bits of sum is used as the IID. + if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize { + panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize)) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: IIDOffsetInIPv6Address * 8, + } +} diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go new file mode 100644 index 000000000..3499d8399 --- /dev/null +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -0,0 +1,551 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier. +type IPv6ExtensionHeaderIdentifier uint8 + +const ( + // IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by + // Hop Options extension header, as per RFC 8200 section 4.3. + IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0 + + // IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension + // header, as per RFC 8200 section 4.4. + IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43 + + // IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment + // extension header, as per RFC 8200 section 4.5. + IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44 + + // IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a + // Destination Options extension header, as per RFC 8200 section 4.6. + IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60 + + // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end + // of an IPv6 payload, as per RFC 8200 section 4.7. + IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59 +) + +const ( + // ipv6UnknownExtHdrOptionActionMask is the mask of the action to take when + // a node encounters an unrecognized option. + ipv6UnknownExtHdrOptionActionMask = 192 + + // ipv6UnknownExtHdrOptionActionShift is the least significant bits to discard + // from the action value for an unrecognized option identifier. + ipv6UnknownExtHdrOptionActionShift = 6 + + // ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field + // within an IPv6RoutingExtHdr. + ipv6RoutingExtHdrSegmentsLeftIdx = 1 + + // IPv6FragmentExtHdrLength is the length of an IPv6 extension header, in + // bytes. + IPv6FragmentExtHdrLength = 8 + + // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the + // Fragment Offset field within an IPv6FragmentExtHdr. + ipv6FragmentExtHdrFragmentOffsetOffset = 0 + + // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to + // discard from the Fragment Offset. + ipv6FragmentExtHdrFragmentOffsetShift = 3 + + // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an + // IPv6FragmentExtHdr. + ipv6FragmentExtHdrFlagsIdx = 1 + + // ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the + // flags field of an IPv6FragmentExtHdr. + ipv6FragmentExtHdrMFlagMask = 1 + + // ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification + // field within an IPv6FragmentExtHdr. + ipv6FragmentExtHdrIdentificationOffset = 2 + + // ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length + // field. That is, given a Length field of 2, the extension header expects + // 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for + // details about the first 8 bytes' exclusion from the Length field). + ipv6ExtHdrLenBytesPerUnit = 8 + + // ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an + // extension header's Length field following the Length field. + // + // The Length field excludes the first 8 bytes, but the Next Header and Length + // field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes + // after the Length field. + // + // This ensures that every extension header is at least 8 bytes. + ipv6ExtHdrLenBytesExcluded = 6 + + // IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment + // extension header's Fragment Offset field. That is, given a Fragment Offset + // of 2, the extension header is indiciating that the fragment's payload + // starts at the 16th byte in the reassembled packet. + IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 +) + +// IPv6PayloadHeader is implemented by the various headers that can be found +// in an IPv6 payload. +// +// These headers include IPv6 extension headers or upper layer data. +type IPv6PayloadHeader interface { + isIPv6PayloadHeader() +} + +// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator +// encounters a Next Header field it does not recognize as an IPv6 extension +// header. +type IPv6RawPayloadHeader struct { + Identifier IPv6ExtensionHeaderIdentifier + Buf buffer.VectorisedView +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {} + +// ipv6OptionsExtHdr is an IPv6 extension header that holds options. +type ipv6OptionsExtHdr []byte + +// Iter returns an iterator over the IPv6 extension header options held in b. +func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { + it := IPv6OptionsExtHdrOptionsIterator{} + it.reader.Reset(b) + return it +} + +// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header +// options. +// +// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last +// used, no changes to the underlying buffer may happen. Doing so may cause +// undefined and unexpected behaviour. It is fine to obtain an +// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then +// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator +// obtained before modification is no longer used. +type IPv6OptionsExtHdrOptionsIterator struct { + reader bytes.Reader +} + +// IPv6OptionUnknownAction is the action that must be taken if the processing +// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2. +type IPv6OptionUnknownAction int + +const ( + // IPv6OptionUnknownActionSkip indicates that the unrecognized option must + // be skipped and the node should continue processing the header. + IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0 + + // IPv6OptionUnknownActionDiscard indicates that the packet must be silently + // discarded. + IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1 + + // IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be + // discarded and the node must send an ICMP Parameter Problem, Code 2, message + // to the packet's source, regardless of whether or not the packet's + // Destination was a multicast address. + IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2 + + // IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the + // packet must be discarded and the node must send an ICMP Parameter Problem, + // Code 2, message to the packet's source only if the packet's Destination was + // not a multicast address. + IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3 +) + +// IPv6ExtHdrOption is implemented by the various IPv6 extension header options. +type IPv6ExtHdrOption interface { + // UnknownAction returns the action to take in response to an unrecognized + // option. + UnknownAction() IPv6OptionUnknownAction + + // isIPv6ExtHdrOption is used to "lock" this interface so it is not + // implemented by other packages. + isIPv6ExtHdrOption() +} + +// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier. +type IPv6ExtHdrOptionIndentifier uint8 + +const ( + // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that + // provides 1 byte padding, as outlined in RFC 8200 section 4.2. + ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0 + + // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that + // provides variable length byte padding, as outlined in RFC 8200 section 4.2. + ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1 +) + +// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension +// header option that is unknown by the parsing utilities. +type IPv6UnknownExtHdrOption struct { + Identifier IPv6ExtHdrOptionIndentifier + Data []byte +} + +// UnknownAction implements IPv6OptionUnknownAction.UnknownAction. +func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { + return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. +func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} + +// Next returns the next option in the options data. +// +// If the next item is not a known extension header option, +// IPv6UnknownExtHdrOption will be returned with the option identifier and data. +// +// The return is of the format (option, done, error). done will be true when +// Next is unable to return anything because the iterator has reached the end of +// the options data, or an error occured. +func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { + for { + temp, err := i.reader.ReadByte() + if err != nil { + // If we can't read the first byte of a new option, then we know the + // options buffer has been exhausted and we are done iterating. + return nil, true, nil + } + id := IPv6ExtHdrOptionIndentifier(temp) + + // If the option identifier indicates the option is a Pad1 option, then we + // know the option does not have Length and Data fields. End processing of + // the Pad1 option and continue processing the buffer as a new option. + if id == ipv6Pad1ExtHdrOptionIdentifier { + continue + } + + length, err := i.reader.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once + // we start parsing an option; we expect the reader to contain enough + // bytes for the whole option. + return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) + } + + // Special-case the variable length padding option to avoid a copy. + if id == ipv6PadNExtHdrOptionIdentifier { + // Do we have enough bytes in the reader for the PadN option? + if n := i.reader.Len(); n < int(length) { + // Reset the reader to effectively consume the remaining buffer. + i.reader.Reset(nil) + + // We return the same error as if we failed to read a non-padding option + // so consumers of this iterator don't need to differentiate between + // padding and non-padding options. + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) + } + + if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil { + panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) + } + + // End processing of the PadN option and continue processing the buffer as + // a new option. + continue + } + + bytes := make([]byte, length) + if n, err := io.ReadFull(&i.reader, bytes); err != nil { + // io.ReadFull may return io.EOF if i.reader has been exhausted. We use + // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the + // Length field found in the option. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) + } + + return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil + } +} + +// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options +// extension header. +type IPv6HopByHopOptionsExtHdr struct { + ipv6OptionsExtHdr +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {} + +// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options +// extension header. +type IPv6DestinationOptionsExtHdr struct { + ipv6OptionsExtHdr +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {} + +// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific +// data as outlined in RFC 8200 section 4.4. +type IPv6RoutingExtHdr []byte + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {} + +// SegmentsLeft returns the Segments Left field. +func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 { + return b[ipv6RoutingExtHdrSegmentsLeftIdx] +} + +// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific +// data as outlined in RFC 8200 section 4.5. +// +// Note, the buffer does not include the Next Header and Reserved fields. +type IPv6FragmentExtHdr [6]byte + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {} + +// FragmentOffset returns the Fragment Offset field. +// +// This value indicates where the buffer following the Fragment extension header +// starts in the target (reassembled) packet. +func (b IPv6FragmentExtHdr) FragmentOffset() uint16 { + return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift +} + +// More returns the More (M) flag. +// +// This indicates whether any fragments are expected to succeed b. +func (b IPv6FragmentExtHdr) More() bool { + return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0 +} + +// ID returns the Identification field. +// +// This value is used to uniquely identify the packet, between a +// souce and destination. +func (b IPv6FragmentExtHdr) ID() uint32 { + return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:]) +} + +// IsAtomic returns whether the fragment header indicates an atomic fragment. An +// atomic fragment is a fragment that contains all the data required to +// reassemble a full packet. +func (b IPv6FragmentExtHdr) IsAtomic() bool { + return !b.More() && b.FragmentOffset() == 0 +} + +// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. +// +// The IPv6 payload may contain IPv6 extension headers before any upper layer +// data. +// +// Note, between when an IPv6PayloadIterator is obtained and last used, no +// changes to the payload may happen. Doing so may cause undefined and +// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate +// over the first few headers then modify the backing payload so long as the +// IPv6PayloadIterator obtained before modification is no longer used. +type IPv6PayloadIterator struct { + // The identifier of the next header to parse. + nextHdrIdentifier IPv6ExtensionHeaderIdentifier + + // reader is an io.Reader over payload. + reader bufio.Reader + payload buffer.VectorisedView + + // Indicates to the iterator that it should return the remaining payload as a + // raw payload on the next call to Next. + forceRaw bool +} + +// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing +// extension headers, or a raw payload if the payload cannot be parsed. +func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.VectorisedView) IPv6PayloadIterator { + readers := payload.Readers() + readerPs := make([]io.Reader, 0, len(readers)) + for i := range readers { + readerPs = append(readerPs, &readers[i]) + } + + return IPv6PayloadIterator{ + nextHdrIdentifier: nextHdrIdentifier, + payload: payload.Clone(nil), + // We need a buffer of size 1 for calls to bufio.Reader.ReadByte. + reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + } +} + +// AsRawHeader returns the remaining payload of i as a raw header and +// optionally consumes the iterator. +// +// If consume is true, calls to Next after calling AsRawHeader on i will +// indicate that the iterator is done. +func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { + identifier := i.nextHdrIdentifier + + var buf buffer.VectorisedView + if consume { + // Since we consume the iterator, we return the payload as is. + buf = i.payload + + // Mark i as done. + *i = IPv6PayloadIterator{ + nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + } + } else { + buf = i.payload.Clone(nil) + } + + return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf} +} + +// Next returns the next item in the payload. +// +// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader +// will be returned with the remaining bytes and next header identifier. +// +// The return is of the format (header, done, error). done will be true when +// Next is unable to return anything because the iterator has reached the end of +// the payload, or an error occured. +func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + // We could be forced to return i as a raw header when the previous header was + // a fragment extension header as the data following the fragment extension + // header may not be complete. + if i.forceRaw { + return i.AsRawHeader(true /* consume */), false, nil + } + + // Is the header we are parsing a known extension header? + switch i.nextHdrIdentifier { + case IPv6HopByHopOptionsExtHdrIdentifier: + nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil) + if err != nil { + return nil, true, err + } + + i.nextHdrIdentifier = nextHdrIdentifier + return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil + case IPv6RoutingExtHdrIdentifier: + nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil) + if err != nil { + return nil, true, err + } + + i.nextHdrIdentifier = nextHdrIdentifier + return IPv6RoutingExtHdr(bytes), false, nil + case IPv6FragmentExtHdrIdentifier: + var data [6]byte + // We ignore the returned bytes becauase we know the fragment extension + // header specific data will fit in data. + nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) + if err != nil { + return nil, true, err + } + + fragmentExtHdr := IPv6FragmentExtHdr(data) + + // If the packet is not the first fragment, do not attempt to parse anything + // after the fragment extension header as the payload following the fragment + // extension header should not contain any headers; the first fragment must + // hold all the headers up to and including any upper layer headers, as per + // RFC 8200 section 4.5. + if fragmentExtHdr.FragmentOffset() != 0 { + i.forceRaw = true + } + + i.nextHdrIdentifier = nextHdrIdentifier + return fragmentExtHdr, false, nil + case IPv6DestinationOptionsExtHdrIdentifier: + nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil) + if err != nil { + return nil, true, err + } + + i.nextHdrIdentifier = nextHdrIdentifier + return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil + case IPv6NoNextHeaderIdentifier: + // This indicates the end of the IPv6 payload. + return nil, true, nil + + default: + // The header we are parsing is not a known extension header. Return the + // raw payload. + return i.AsRawHeader(true /* consume */), false, nil + } +} + +// nextHeaderData returns the extension header's Next Header field and raw data. +// +// fragmentHdr indicates that the extension header being parsed is the Fragment +// extension header so the Length field should be ignored as it is Reserved +// for the Fragment extension header. +// +// If bytes is not nil, extension header specific data will be read into bytes +// if it has enough capacity. If bytes is provided but does not have enough +// capacity for the data, nextHeaderData will panic. +func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, []byte, error) { + // We ignore the number of bytes read because we know we will only ever read + // at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read + // would return io.EOF to indicate that io.Reader has reached the end of the + // payload. + nextHdrIdentifier, err := i.reader.ReadByte() + i.payload.TrimFront(1) + if err != nil { + return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + + var length uint8 + length, err = i.reader.ReadByte() + i.payload.TrimFront(1) + if err != nil { + if fragmentHdr { + return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + + return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + if fragmentHdr { + length = 0 + } + + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded + if bytes == nil { + bytes = make([]byte, bytesLen) + } else if n := len(bytes); n < bytesLen { + panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier)) + } + + n, err := io.ReadFull(&i.reader, bytes) + i.payload.TrimFront(n) + if err != nil { + return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err) + } + + return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil +} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go new file mode 100644 index 000000000..ab20c5f37 --- /dev/null +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -0,0 +1,992 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// Equal returns true of a and b are equivalent. +// +// Note, Equal will return true if a and b hold the same Identifier value and +// contain the same bytes in Buf, even if the bytes are split across views +// differently. +// +// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported +// fields. +func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool { + return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView()) +} + +// Equal returns true of a and b are equivalent. +// +// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs. +// +// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported +// fields. +func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool { + return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr) +} + +// Equal returns true of a and b are equivalent. +// +// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs. +// +// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported +// fields. +func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool { + return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr) +} + +func TestIPv6UnknownExtHdrOption(t *testing.T) { + tests := []struct { + name string + identifier IPv6ExtHdrOptionIndentifier + expectedUnknownAction IPv6OptionUnknownAction + }{ + { + name: "Skip with zero LSBs", + identifier: 0, + expectedUnknownAction: IPv6OptionUnknownActionSkip, + }, + { + name: "Discard with zero LSBs", + identifier: 64, + expectedUnknownAction: IPv6OptionUnknownActionDiscard, + }, + { + name: "Discard and ICMP with zero LSBs", + identifier: 128, + expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP, + }, + { + name: "Discard and ICMP for non multicast destination with zero LSBs", + identifier: 192, + expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + }, + { + name: "Skip with non-zero LSBs", + identifier: 63, + expectedUnknownAction: IPv6OptionUnknownActionSkip, + }, + { + name: "Discard with non-zero LSBs", + identifier: 127, + expectedUnknownAction: IPv6OptionUnknownActionDiscard, + }, + { + name: "Discard and ICMP with non-zero LSBs", + identifier: 191, + expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP, + }, + { + name: "Discard and ICMP for non multicast destination with non-zero LSBs", + identifier: 255, + expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}} + if a := opt.UnknownAction(); a != test.expectedUnknownAction { + t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction) + } + }) + } + +} + +func TestIPv6OptionsExtHdrIterErr(t *testing.T) { + tests := []struct { + name string + bytes []byte + err error + }{ + { + name: "Single unknown with zero length", + bytes: []byte{255, 0}, + }, + { + name: "Single unknown with non-zero length", + bytes: []byte{255, 3, 1, 2, 3}, + }, + { + name: "Two options", + bytes: []byte{ + 255, 0, + 254, 1, 1, + }, + }, + { + name: "Three options", + bytes: []byte{ + 255, 0, + 254, 1, 1, + 253, 4, 2, 3, 4, 5, + }, + }, + { + name: "Single unknown only identifier", + bytes: []byte{255}, + err: io.ErrUnexpectedEOF, + }, + { + name: "Single unknown too small with length = 1", + bytes: []byte{255, 1}, + err: io.ErrUnexpectedEOF, + }, + { + name: "Single unknown too small with length = 2", + bytes: []byte{255, 2, 1}, + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid first with second unknown only identifier", + bytes: []byte{ + 255, 0, + 254, + }, + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid first with second unknown missing data", + bytes: []byte{ + 255, 0, + 254, 1, + }, + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid first with second unknown too small", + bytes: []byte{ + 255, 0, + 254, 2, 1, + }, + err: io.ErrUnexpectedEOF, + }, + { + name: "One Pad1", + bytes: []byte{0}, + }, + { + name: "Multiple Pad1", + bytes: []byte{0, 0, 0}, + }, + { + name: "Multiple PadN", + bytes: []byte{ + // Pad3 + 1, 1, 1, + + // Pad5 + 1, 3, 1, 2, 3, + }, + }, + { + name: "Pad5 too small middle of data buffer", + bytes: []byte{1, 3, 1, 2}, + err: io.ErrUnexpectedEOF, + }, + { + name: "Pad5 no data", + bytes: []byte{1, 3}, + err: io.ErrUnexpectedEOF, + }, + } + + check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) { + for i := 0; ; i++ { + _, done, err := it.Next() + if err != nil { + // If we encountered a non-nil error while iterating, make sure it is + // is the same error as expectedErr. + if !errors.Is(err, expectedErr) { + t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr) + } + + return + } + if done { + // If we are done (without an error), make sure that we did not expect + // an error. + if expectedErr != nil { + t.Fatalf("expected error when iterating; want = %s", expectedErr) + } + + return + } + } + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Run("Hop By Hop", func(t *testing.T) { + extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} + check(t, extHdr.Iter(), test.err) + }) + + t.Run("Destination", func(t *testing.T) { + extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} + check(t, extHdr.Iter(), test.err) + }) + }) + } +} + +func TestIPv6OptionsExtHdrIter(t *testing.T) { + tests := []struct { + name string + bytes []byte + expected []IPv6ExtHdrOption + }{ + { + name: "Single unknown with zero length", + bytes: []byte{255, 0}, + expected: []IPv6ExtHdrOption{ + &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}}, + }, + }, + { + name: "Single unknown with non-zero length", + bytes: []byte{255, 3, 1, 2, 3}, + expected: []IPv6ExtHdrOption{ + &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}}, + }, + }, + { + name: "Single Pad1", + bytes: []byte{0}, + }, + { + name: "Two Pad1", + bytes: []byte{0, 0}, + }, + { + name: "Single Pad3", + bytes: []byte{1, 1, 1}, + }, + { + name: "Single Pad5", + bytes: []byte{1, 3, 1, 2, 3}, + }, + { + name: "Multiple Pad", + bytes: []byte{ + // Pad1 + 0, + + // Pad2 + 1, 0, + + // Pad3 + 1, 1, 1, + + // Pad4 + 1, 2, 1, 2, + + // Pad5 + 1, 3, 1, 2, 3, + }, + }, + { + name: "Multiple options", + bytes: []byte{ + // Pad1 + 0, + + // Unknown + 255, 0, + + // Pad2 + 1, 0, + + // Unknown + 254, 1, 1, + + // Pad3 + 1, 1, 1, + + // Unknown + 253, 4, 2, 3, 4, 5, + + // Pad4 + 1, 2, 1, 2, + }, + expected: []IPv6ExtHdrOption{ + &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}}, + &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}}, + &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}}, + }, + }, + } + + checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) { + for i, e := range expected { + opt, done, err := it.Next() + if err != nil { + t.Errorf("(i=%d) Next(): %s", i, err) + } + if done { + t.Errorf("(i=%d) unexpectedly done iterating", i) + } + if diff := cmp.Diff(e, opt); diff != "" { + t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff) + } + + if t.Failed() { + t.FailNow() + } + } + + opt, done, err := it.Next() + if err != nil { + t.Errorf("(last) Next(): %s", err) + } + if !done { + t.Errorf("(last) iterator unexpectedly not done") + } + if opt != nil { + t.Errorf("(last) got Next() = %T, want = nil", opt) + } + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Run("Hop By Hop", func(t *testing.T) { + extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} + checkIter(t, extHdr.Iter(), test.expected) + }) + + t.Run("Destination", func(t *testing.T) { + extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} + checkIter(t, extHdr.Iter(), test.expected) + }) + }) + } +} + +func TestIPv6RoutingExtHdr(t *testing.T) { + tests := []struct { + name string + bytes []byte + segmentsLeft uint8 + }{ + { + name: "Zeroes", + bytes: []byte{0, 0, 0, 0, 0, 0}, + segmentsLeft: 0, + }, + { + name: "Ones", + bytes: []byte{1, 1, 1, 1, 1, 1}, + segmentsLeft: 1, + }, + { + name: "Mixed", + bytes: []byte{1, 2, 3, 4, 5, 6}, + segmentsLeft: 2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + extHdr := IPv6RoutingExtHdr(test.bytes) + if got := extHdr.SegmentsLeft(); got != test.segmentsLeft { + t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft) + } + }) + } +} + +func TestIPv6FragmentExtHdr(t *testing.T) { + tests := []struct { + name string + bytes [6]byte + fragmentOffset uint16 + more bool + id uint32 + }{ + { + name: "Zeroes", + bytes: [6]byte{0, 0, 0, 0, 0, 0}, + fragmentOffset: 0, + more: false, + id: 0, + }, + { + name: "Ones", + bytes: [6]byte{0, 9, 0, 0, 0, 1}, + fragmentOffset: 1, + more: true, + id: 1, + }, + { + name: "Mixed", + bytes: [6]byte{68, 9, 128, 4, 2, 1}, + fragmentOffset: 2177, + more: true, + id: 2147746305, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + extHdr := IPv6FragmentExtHdr(test.bytes) + if got := extHdr.FragmentOffset(); got != test.fragmentOffset { + t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset) + } + if got := extHdr.More(); got != test.more { + t.Errorf("got More() = %t, want = %t", got, test.more) + } + if got := extHdr.ID(); got != test.id { + t.Errorf("got ID() = %d, want = %d", got, test.id) + } + }) + } +} + +func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView { + size := 0 + var vs []buffer.View + + for _, b := range bs { + vs = append(vs, buffer.View(b)) + size += len(b) + } + + return buffer.NewVectorisedView(size, vs) +} + +func TestIPv6ExtHdrIterErr(t *testing.T) { + tests := []struct { + name string + firstNextHdr IPv6ExtensionHeaderIdentifier + payload buffer.VectorisedView + err error + }{ + { + name: "Upper layer only without data", + firstNextHdr: 255, + }, + { + name: "Upper layer only with data", + firstNextHdr: 255, + payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), + }, + { + name: "No next header", + firstNextHdr: IPv6NoNextHeaderIdentifier, + }, + { + name: "No next header with data", + firstNextHdr: IPv6NoNextHeaderIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), + }, + { + name: "Valid single hop by hop", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}), + }, + { + name: "Hop by hop too small", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid single fragment", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}), + }, + { + name: "Fragment too small", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid single destination", + firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}), + }, + { + name: "Destination too small", + firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid single routing", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}), + }, + { + name: "Valid single routing across views", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}), + }, + { + name: "Routing too small with zero length field", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid routing with non-zero length field", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}), + }, + { + name: "Valid routing with non-zero length field across views", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}), + }, + { + name: "Routing too small with non-zero length field", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Routing too small with non-zero length field across views", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Mixed", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Hop By Hop Options extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // (Atomic) Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, + + // Routing extension header. + uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Destination Options extension header. + 255, 0, 255, 4, 1, 2, 3, 4, + + // Upper layer data. + 1, 2, 3, 4, + }), + }, + { + name: "Mixed without upper layer data", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Hop By Hop Options extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // (Atomic) Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, + + // Routing extension header. + uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Destination Options extension header. + 255, 0, 255, 4, 1, 2, 3, 4, + }), + }, + { + name: "Mixed without upper layer data but last ext hdr too small", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Hop By Hop Options extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // (Atomic) Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, + + // Routing extension header. + uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Destination Options extension header. + 255, 0, 255, 4, 1, 2, 3, + }), + err: io.ErrUnexpectedEOF, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload) + + for i := 0; ; i++ { + _, done, err := it.Next() + if err != nil { + // If we encountered a non-nil error while iterating, make sure it is + // is the same error as test.err. + if !errors.Is(err, test.err) { + t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err) + } + + return + } + if done { + // If we are done (without an error), make sure that we did not expect + // an error. + if test.err != nil { + t.Fatalf("expected error when iterating; want = %s", test.err) + } + + return + } + } + }) + } +} + +func TestIPv6ExtHdrIter(t *testing.T) { + routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4}) + upperLayerData := buffer.View([]byte{1, 2, 3, 4}) + tests := []struct { + name string + firstNextHdr IPv6ExtensionHeaderIdentifier + payload buffer.VectorisedView + expected []IPv6PayloadHeader + }{ + // With a non-atomic fragment that is not the first fragment, the payload + // after the fragment will not be parsed because the payload is expected to + // only hold upper layer data. + { + name: "hopbyhop - fragment (not first) - routing - upper", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Hop By Hop extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // Fragment extension header. + // + // More = 1, Fragment Offset = 2117, ID = 2147746305 + uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, + + // Routing extension header. + // + // Even though we have a routing ext header here, it should be + // be interpretted as raw bytes as only the first fragment is expected + // to hold headers. + 255, 0, 1, 2, 3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, + IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), + IPv6RawPayloadHeader{ + Identifier: IPv6RoutingExtHdrIdentifier, + Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), + }, + }, + }, + { + name: "hopbyhop - fragment (first) - routing - upper", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Hop By Hop extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // Fragment extension header. + // + // More = 1, Fragment Offset = 0, ID = 2147746305 + uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2, 3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, + IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}), + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6RawPayloadHeader{ + Identifier: 255, + Buf: upperLayerData.ToVectorisedView(), + }, + }, + }, + { + name: "fragment - routing - upper (across views)", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2}, []byte{3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), + IPv6RawPayloadHeader{ + Identifier: IPv6RoutingExtHdrIdentifier, + Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), + }, + }, + }, + + // If we have an atomic fragment, the payload following the fragment + // extension header should be parsed normally. + { + name: "atomic fragment - routing - destination - upper", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, + + // Routing extension header. + uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Destination Options extension header. + 255, 0, 1, 4, 1, 2, 3, 4, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, + IPv6RawPayloadHeader{ + Identifier: 255, + Buf: upperLayerData.ToVectorisedView(), + }, + }, + }, + { + name: "atomic fragment - routing - upper (across views)", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2}, []byte{3, 4, 5, 6, + + // Upper layer data. + 1, 2}, []byte{3, 4}), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6RawPayloadHeader{ + Identifier: 255, + Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), + }, + }, + }, + { + name: "atomic fragment - destination - no next header", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + // + // Res (Reserved) bits are 1 which should not affect anything. + uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1, + + // Destination Options extension header. + uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, + }, + }, + { + name: "routing - atomic fragment - no next header", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Routing extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + }, + }, + { + name: "routing - atomic fragment - no next header (across views)", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Routing extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + }, + }, + { + name: "hopbyhop - routing - fragment - no next header", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Hop By Hop Options extension header. + uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // Routing extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Fragment extension header. + // + // Fragment Offset = 32; Res = 6. + uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}), + IPv6RawPayloadHeader{ + Identifier: IPv6NoNextHeaderIdentifier, + Buf: upperLayerData.ToVectorisedView(), + }, + }, + }, + + // Test the raw payload for common transport layer protocol numbers. + { + name: "TCP raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "UDP raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "ICMPv4 raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "ICMPv6 raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "Unknwon next header raw payload", + firstNextHdr: 255, + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: 255, + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "Unknwon next header raw payload (across views)", + firstNextHdr: 255, + payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: 255, + Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), + }}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload) + + for i, e := range test.expected { + extHdr, done, err := it.Next() + if err != nil { + t.Errorf("(i=%d) Next(): %s", i, err) + } + if done { + t.Errorf("(i=%d) unexpectedly done iterating", i) + } + if diff := cmp.Diff(e, extHdr); diff != "" { + t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff) + } + + if t.Failed() { + t.FailNow() + } + } + + extHdr, done, err := it.Next() + if err != nil { + t.Errorf("(last) Next(): %s", err) + } + if !done { + t.Errorf("(last) iterator unexpectedly not done") + } + if extHdr != nil { + t.Errorf("(last) got Next() = %T, want = nil", extHdr) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go new file mode 100644 index 000000000..426a873b1 --- /dev/null +++ b/pkg/tcpip/header/ipv6_test.go @@ -0,0 +1,417 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "bytes" + "crypto/sha256" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") +) + +func TestEthernetAdddressToModifiedEUI64(t *testing.T) { + expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} + + if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) + } + + var buf [header.IIDSize]byte + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + if diff := cmp.Diff(expectedIID, buf); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) + } +} + +func TestLinkLocalAddr(t *testing.T) { + if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) + } +} + +func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + prefix: header.IPv6LinkLocalPrefix.Subnet(), + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h := sha256.New() + h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) + h.Write([]byte(test.nicName)) + h.Write([]byte{test.dadCounter}) + if k := test.secretKey; k != nil { + h.Write(k) + } + var hashSum [sha256.Size]byte + h.Sum(hashSum[:0]) + want := hashSum[:header.IIDSize] + + // Passing a nil buffer should result in a new buffer returned with the + // IID. + if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + + // Passing a buffer with sufficient capacity for the IID should populate + // the buffer provided. + var iidBuf [header.IIDSize]byte + if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + if got := iidBuf[:]; !bytes.Equal(got, want) { + t.Errorf("got iidBuf = %x, want = %x", got, want) + } + }) + } +} + +func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + prefix := header.IPv6LinkLocalPrefix.Subnet() + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + addrBytes := [header.IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + test.nicName, + test.dadCounter, + test.secretKey, + )) + + if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { + t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) + } + }) + } +} + +func TestIsV6UniqueLocalAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Unique 1", + addr: uniqueLocalAddr1, + expected: true, + }, + { + name: "Valid Unique 2", + addr: uniqueLocalAddr1, + expected: true, + }, + { + name: "Link Local", + addr: linkLocalAddr, + expected: false, + }, + { + name: "Global", + addr: globalAddr, + expected: false, + }, + { + name: "IPv4", + addr: "\x01\x02\x03\x04", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV6LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: true, + }, + { + name: "Valid Link Local Multicast with flags", + addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + expected: true, + }, + { + name: "Link Local Unicast", + addr: linkLocalAddr, + expected: false, + }, + { + name: "IPv4 Multicast", + addr: "\xe0\x00\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV6LinkLocalAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Unicast", + addr: linkLocalAddr, + expected: true, + }, + { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: false, + }, + { + name: "Unique Local", + addr: uniqueLocalAddr1, + expected: false, + }, + { + name: "Global", + addr: globalAddr, + expected: false, + }, + { + name: "IPv4 Link Local", + addr: "\xa9\xfe\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestScopeForIPv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + scope header.IPv6AddressScope + err *tcpip.Error + }{ + { + name: "Unique Local", + addr: uniqueLocalAddr1, + scope: header.UniqueLocalScope, + err: nil, + }, + { + name: "Link Local Unicast", + addr: linkLocalAddr, + scope: header.LinkLocalScope, + err: nil, + }, + { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + scope: header.LinkLocalScope, + err: nil, + }, + { + name: "Global", + addr: globalAddr, + scope: header.GlobalScope, + err: nil, + }, + { + name: "IPv4", + addr: "\x01\x02\x03\x04", + scope: header.GlobalScope, + err: tcpip.ErrBadAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := header.ScopeForIPv6Address(test.addr) + if err != test.err { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err) + } + if got != test.scope { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) + } + }) + } +} + +func TestSolicitedNodeAddr(t *testing.T) { + tests := []struct { + addr tcpip.Address + want tcpip.Address + }{ + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", + }, + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", + }, + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03", + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { + if got := header.SolicitedNodeAddr(test.addr); got != test.want { + t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want) + } + }) + } +} diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 98310ea23..5d3975c56 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -15,24 +15,46 @@ package header import ( + "bytes" "encoding/binary" + "errors" + "fmt" + "io" + "math" "time" "gvisor.dev/gvisor/pkg/tcpip" ) +// NDPOptionIdentifier is an NDP option type identifier. +type NDPOptionIdentifier uint8 + const ( - // NDPTargetLinkLayerAddressOptionType is the type of the Target - // Link-Layer Address option, as per RFC 4861 section 4.6.1. - NDPTargetLinkLayerAddressOptionType = 2 + // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer + // Address option, as per RFC 4861 section 4.6.1. + NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1 - // ndpTargetEthernetLinkLayerAddressSize is the size of a Target - // Link Layer Option for an Ethernet address. - ndpTargetEthernetLinkLayerAddressSize = 8 + // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer + // Address option, as per RFC 4861 section 4.6.1. + NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2 - // ndpPrefixInformationType is the type of the Prefix Information + // NDPPrefixInformationType is the type of the Prefix Information // option, as per RFC 4861 section 4.6.2. - ndpPrefixInformationType = 3 + NDPPrefixInformationType NDPOptionIdentifier = 3 + + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25 + + // NDPDNSSearchListOptionType is the type of the DNS Search List option, + // as per RFC 8106 section 5.2. + NDPDNSSearchListOptionType = 31 +) + +const ( + // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer + // Address option for an Ethernet address. + NDPLinkLayerAddressSize = 8 // ndpPrefixInformationLength is the expected length, in bytes, of the // body of an NDP Prefix Information option, as per RFC 4861 section @@ -84,10 +106,39 @@ const ( // within an NDPPrefixInformation. ndpPrefixInformationPrefixOffset = 14 - // NDPPrefixInformationInfiniteLifetime is a value that represents - // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix - // Information option. Its value is (2^32 - 1)s = 4294967295s - NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS Server + // option's body size when it contains at least one IPv6 address, as per + // RFC 8106 section 5.3.1. + minNDPRecursiveDNSServerBodySize = 22 + + // ndpDNSSearchListLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPDNSSearchList. + ndpDNSSearchListLifetimeOffset = 2 + + // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list + // domain names within an NDPDNSSearchList. + ndpDNSSearchListDomainNamesOffset = 6 + + // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's + // body size when it contains at least one domain name, as per RFC 8106 + // section 5.3.1. + minNDPDNSSearchListBodySize = 14 + + // maxDomainNameLabelLength is the maximum length of a domain name + // label, as per RFC 1035 section 3.1. + maxDomainNameLabelLength = 63 + + // maxDomainNameLength is the maximum length of a domain name, including + // label AND label length octet, as per RFC 1035 section 3.1. + maxDomainNameLength = 255 // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of @@ -95,9 +146,158 @@ const ( lengthByteUnits = 8 ) +var ( + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. + // + // This is a variable instead of a constant so that tests can change + // this value to a smaller value. It should only be modified by tests. + NDPInfiniteLifetime = time.Second * math.MaxUint32 +) + +// NDPOptionIterator is an iterator of NDPOption. +// +// Note, between when an NDPOptionIterator is obtained and last used, no changes +// to the NDPOptions may happen. Doing so may cause undefined and unexpected +// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first +// few NDPOption then modify the backing NDPOptions so long as the +// NDPOptionIterator obtained before modification is no longer used. +type NDPOptionIterator struct { + opts *bytes.Buffer +} + +// Potential errors when iterating over an NDPOptions. +var ( + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header") +) + +// Next returns the next element in the backing NDPOptions, or true if we are +// done, or false if an error occured. +// +// The return can be read as option, done, error. Note, option should only be +// used if done is false and error is nil. +func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { + for { + // Do we still have elements to look at? + if i.opts.Len() == 0 { + return nil, true, nil + } + + // Get the Type field. + temp, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once + // we start parsing an option; we expect the buffer to contain enough + // bytes for the whole option. + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF) + } + kind := NDPOptionIdentifier(temp) + + // Get the Length field. + length, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err)) + } + + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF) + } + + // This would indicate an erroneous NDP option as the Length field should + // never be 0. + if length == 0 { + return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader) + } + + // Get the body. + numBytes := int(length) * lengthByteUnits + numBodyBytes := numBytes - 2 + body := i.opts.Next(numBodyBytes) + if len(body) < numBodyBytes { + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF) + } + + switch kind { + case NDPSourceLinkLayerAddressOptionType: + return NDPSourceLinkLayerAddressOption(body), false, nil + + case NDPTargetLinkLayerAddressOptionType: + return NDPTargetLinkLayerAddressOption(body), false, nil + + case NDPPrefixInformationType: + // Make sure the length of a Prefix Information option + // body is ndpPrefixInformationLength, as per RFC 4861 + // section 4.6.2. + if numBodyBytes != ndpPrefixInformationLength { + return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody) + } + + return NDPPrefixInformation(body), false, nil + + case NDPRecursiveDNSServerOptionType: + opt := NDPRecursiveDNSServer(body) + if err := opt.checkAddresses(); err != nil { + return nil, true, err + } + + return opt, false, nil + + case NDPDNSSearchListOptionType: + opt := NDPDNSSearchList(body) + if err := opt.checkDomainNames(); err != nil { + return nil, true, err + } + + return opt, false, nil + + default: + // We do not yet recognize the option, just skip for + // now. This is okay because RFC 4861 allows us to + // skip/ignore any unrecognized options. However, + // we MUST recognized all the options in RFC 4861. + // + // TODO(b/141487990): Handle all NDP options as defined + // by RFC 4861. + } + } +} + // NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6. type NDPOptions []byte +// Iter returns an iterator of NDPOption. +// +// If check is true, Iter will do an integrity check on the options by iterating +// over it and returning an error if detected. +// +// See NDPOptionIterator for more information. +func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { + it := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } + + if check { + it2 := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } + + for { + if _, done, err := it2.Next(); err != nil || done { + return it, err + } + } + } + + return it, nil +} + // Serialize serializes the provided list of NDP options into o. // // Note, b must be of sufficient size to hold all the options in s. See @@ -116,7 +316,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { continue } - b[0] = o.Type() + b[0] = byte(o.Type()) // We know this safe because paddedLength would have returned // 0 if o had an invalid length (> 255 * lengthByteUnits). @@ -137,15 +337,17 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { return done } -// ndpOption is the set of functions to be implemented by all NDP option types. -type ndpOption interface { - // Type returns the type of this ndpOption. - Type() uint8 +// NDPOption is the set of functions to be implemented by all NDP option types. +type NDPOption interface { + fmt.Stringer + + // Type returns the type of the receiver. + Type() NDPOptionIdentifier - // Length returns the length of the body of this ndpOption, in bytes. + // Length returns the length of the body of the receiver, in bytes. Length() int - // serializeInto serializes this ndpOption into the provided byte + // serializeInto serializes the receiver into the provided byte // buffer. // // Note, the caller MUST provide a byte buffer with size of at least @@ -154,15 +356,15 @@ type ndpOption interface { // buffer is not of sufficient size. // // serializeInto will return the number of bytes that was used to - // serialize this ndpOption. Implementers must only use the number of - // bytes required to serialize this ndpOption. Callers MAY provide a + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a // larger buffer than required to serialize into. serializeInto([]byte) int } // paddedLength returns the length of o, in bytes, with any padding bytes, if // required. -func paddedLength(o ndpOption) int { +func paddedLength(o NDPOption) int { l := o.Length() if l == 0 { @@ -201,7 +403,7 @@ func paddedLength(o ndpOption) int { } // NDPOptionsSerializer is a serializer for NDP options. -type NDPOptionsSerializer []ndpOption +type NDPOptionsSerializer []NDPOption // Length returns the total number of bytes required to serialize. func (b NDPOptionsSerializer) Length() int { @@ -214,6 +416,46 @@ func (b NDPOptionsSerializer) Length() int { return l } +// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option +// as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPSourceLinkLayerAddressOption tcpip.LinkAddress + +// Type implements NDPOption.Type. +func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier { + return NDPSourceLinkLayerAddressOptionType +} + +// Length implements NDPOption.Length. +func (o NDPSourceLinkLayerAddressOption) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer.String. +func (o NDPSourceLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + // NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option // as defined by RFC 4861 section 4.6.1. // @@ -221,21 +463,39 @@ func (b NDPOptionsSerializer) Length() int { // where X is the value in Length multiplied by lengthByteUnits - 2 bytes. type NDPTargetLinkLayerAddressOption tcpip.LinkAddress -// Type implements ndpOption.Type. -func (o NDPTargetLinkLayerAddressOption) Type() uint8 { +// Type implements NDPOption.Type. +func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier { return NDPTargetLinkLayerAddressOptionType } -// Length implements ndpOption.Length. +// Length implements NDPOption.Length. func (o NDPTargetLinkLayerAddressOption) Length() int { return len(o) } -// serializeInto implements ndpOption.serializeInto. +// serializeInto implements NDPOption.serializeInto. func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } +// String implements fmt.Stringer.String. +func (o NDPTargetLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + // NDPPrefixInformation is the NDP Prefix Information option as defined by // RFC 4861 section 4.6.2. // @@ -243,17 +503,17 @@ func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { // ndpPrefixInformationLength bytes. type NDPPrefixInformation []byte -// Type implements ndpOption.Type. -func (o NDPPrefixInformation) Type() uint8 { - return ndpPrefixInformationType +// Type implements NDPOption.Type. +func (o NDPPrefixInformation) Type() NDPOptionIdentifier { + return NDPPrefixInformationType } -// Length implements ndpOption.Length. +// Length implements NDPOption.Length. func (o NDPPrefixInformation) Length() int { return ndpPrefixInformationLength } -// serializeInto implements ndpOption.serializeInto. +// serializeInto implements NDPOption.serializeInto. func (o NDPPrefixInformation) serializeInto(b []byte) int { used := copy(b, o) @@ -269,6 +529,17 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int { return used } +// String implements fmt.Stringer.String. +func (o NDPPrefixInformation) String() string { + return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)", + o, + o.OnLinkFlag(), + o.AutonomousAddressConfigurationFlag(), + o.PreferredLifetime(), + o.ValidLifetime(), + o.Subnet()) +} + // PrefixLength returns the value in the number of leading bits in the Prefix // that are valid. // @@ -302,7 +573,7 @@ func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { // // Note, a value of 0 implies the prefix should not be considered as on-link, // and a value of infinity/forever is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. func (o NDPPrefixInformation) ValidLifetime() time.Duration { // The field is the time in seconds, as per RFC 4861 section 4.6.2. return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) @@ -315,7 +586,7 @@ func (o NDPPrefixInformation) ValidLifetime() time.Duration { // // Note, a value of 0 implies that addresses generated from the prefix should // no longer remain preferred, and a value of infinity is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. // // Also note that the value of this field MUST NOT exceed the Valid Lifetime // field to avoid preferring addresses that are no longer valid, for the @@ -334,3 +605,295 @@ func (o NDPPrefixInformation) PreferredLifetime() time.Duration { func (o NDPPrefixInformation) Prefix() tcpip.Address { return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) } + +// Subnet returns the Prefix field and Prefix Length field represented in a +// tcpip.Subnet. +func (o NDPPrefixInformation) Subnet() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: o.Prefix(), + PrefixLen: int(o.PrefixLength()), + } + return addrWithPrefix.Subnet() +} + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// Type implements NDPOption.Type. +func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier { + return NDPRecursiveDNSServerOptionType +} + +// Length implements NDPOption.Length. +func (o NDPRecursiveDNSServer) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// String implements fmt.Stringer.String. +func (o NDPRecursiveDNSServer) String() string { + lt := o.Lifetime() + addrs, err := o.Addresses() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt) +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) { + var addrs []tcpip.Address + return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) }) +} + +// checkAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and returns any error it encounters. +func (o NDPRecursiveDNSServer) checkAddresses() error { + return o.iterAddresses(nil) +} + +// iterAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and calls a function with each valid unicast IPv6 address. +// +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error { + if l := len(o); l < minNDPRecursiveDNSServerBodySize { + return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF) + } + + o = o[ndpRecursiveDNSServerAddressesOffset:] + l := len(o) + if l%IPv6AddressSize != 0 { + return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody) + } + + for i := 0; len(o) != 0; i++ { + addr := tcpip.Address(o[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody) + } + + if fn != nil { + fn(addr) + } + + o = o[IPv6AddressSize:] + } + + return nil +} + +// NDPDNSSearchList is the NDP DNS Search List option, as defined by +// RFC 8106 section 5.2. +type NDPDNSSearchList []byte + +// Type implements NDPOption.Type. +func (o NDPDNSSearchList) Type() NDPOptionIdentifier { + return NDPDNSSearchListOptionType +} + +// Length implements NDPOption.Length. +func (o NDPDNSSearchList) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPDNSSearchList) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpDNSSearchListLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// String implements fmt.Stringer.String. +func (o NDPDNSSearchList) String() string { + lt := o.Lifetime() + domainNames, err := o.DomainNames() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt) +} + +// Lifetime returns the length of time that the DNS search list of domain names +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the domain names should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +func (o NDPDNSSearchList) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:])) +} + +// DomainNames returns a DNS search list of domain names. +// +// DomainNames will parse the backing buffer as outlined by RFC 1035 section +// 3.1 and return a list of strings, with all domain names in lower case. +func (o NDPDNSSearchList) DomainNames() ([]string, error) { + var domainNames []string + return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) }) +} + +// checkDomainNames iterates over the domain names in an NDP DNS Search List +// option and returns any error it encounters. +func (o NDPDNSSearchList) checkDomainNames() error { + return o.iterDomainNames(nil) +} + +// iterDomainNames iterates over the domain names in an NDP DNS Search List +// option and calls a function with each valid domain name. +func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error { + if l := len(o); l < minNDPDNSSearchListBodySize { + return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF) + } + + var searchList bytes.Reader + searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:]) + + var scratch [maxDomainNameLength]byte + domainName := bytes.NewBuffer(scratch[:]) + + // Parse the domain names, as per RFC 1035 section 3.1. + for searchList.Len() != 0 { + domainName.Reset() + + // Parse a label within a domain name, as per RFC 1035 section 3.1. + for { + // The first byte is the label length. + labelLenByte, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected + // once we start parsing a domain name; we expect the buffer to contain + // enough bytes for the whole domain name. + return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF) + } + labelLen := int(labelLenByte) + + // A zero-length label implies the end of a domain name. + if labelLen == 0 { + // If the domain name is empty or we have no callback function, do + // nothing further with the current domain name. + if domainName.Len() == 0 || fn == nil { + break + } + + // Ignore the trailing period in the parsed domain name. + domainName.Truncate(domainName.Len() - 1) + fn(domainName.String()) + break + } + + // The label's length must not exceed the maximum length for a label. + if labelLen > maxDomainNameLabelLength { + return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody) + } + + // The label (and trailing period) must not make the domain name too long. + if labelLen+1 > domainName.Cap()-domainName.Len() { + return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody) + } + + // Copy the label and add a trailing period. + for i := 0; i < labelLen; i++ { + b, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err)) + } + + return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF) + } + + // As per RFC 1035 section 2.3.1: + // 1) the label must only contain ASCII include letters, digits and + // hyphens + // 2) the first character in a label must be a letter + // 3) the last letter in a label must be a letter or digit + + if !isLetter(b) { + if i == 0 { + return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + + if b == '-' { + if i == labelLen-1 { + return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody) + } + } else if !isDigit(b) { + return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + } + + // If b is an upper case character, make it lower case. + if isUpperLetter(b) { + b = b - 'A' + 'a' + } + + if err := domainName.WriteByte(b); err != nil { + panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err)) + } + } + if err := domainName.WriteByte('.'); err != nil { + panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err)) + } + } + } + + return nil +} + +func isLetter(b byte) bool { + return b >= 'a' && b <= 'z' || isUpperLetter(b) +} + +func isUpperLetter(b byte) bool { + return b >= 'A' && b <= 'Z' +} + +func isDigit(b byte) bool { + return b >= '0' && b <= '9' +} diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go new file mode 100644 index 000000000..9e67ba95d --- /dev/null +++ b/pkg/tcpip/header/ndp_router_solicit.go @@ -0,0 +1,36 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.1 for more details. +type NDPRouterSolicit []byte + +const ( + // NDPRSMinimumSize is the minimum size of a valid NDP Router + // Solicitation message (body of an ICMPv6 packet). + NDPRSMinimumSize = 4 + + // ndpRSOptionsOffset is the start of the NDP options in an + // NDPRouterSolicit. + ndpRSOptionsOffset = 4 +) + +// Options returns an NDPOptions of the the options body. +func (b NDPRouterSolicit) Options() NDPOptions { + return NDPOptions(b[ndpRSOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 0bbf67a2b..dc4591253 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -16,9 +16,14 @@ package header import ( "bytes" + "errors" + "fmt" + "io" + "regexp" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -36,18 +41,18 @@ func TestNDPNeighborSolicit(t *testing.T) { ns := NDPNeighborSolicit(b) addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") if got := ns.TargetAddress(); got != addr { - t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr) + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) } // Test updating the Target Address. addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") ns.SetTargetAddress(addr2) if got := ns.TargetAddress(); got != addr2 { - t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2) + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) } // Make sure the address got updated in the backing buffer. if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { - t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2) + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) } } @@ -65,56 +70,56 @@ func TestNDPNeighborAdvert(t *testing.T) { na := NDPNeighborAdvert(b) addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") if got := na.TargetAddress(); got != addr { - t.Fatalf("got TargetAddress = %s, want %s", got, addr) + t.Errorf("got TargetAddress = %s, want %s", got, addr) } // Test getting the Router Flag. if got := na.RouterFlag(); !got { - t.Fatalf("got RouterFlag = false, want = true") + t.Errorf("got RouterFlag = false, want = true") } // Test getting the Solicited Flag. if got := na.SolicitedFlag(); got { - t.Fatalf("got SolicitedFlag = true, want = false") + t.Errorf("got SolicitedFlag = true, want = false") } // Test getting the Override Flag. if got := na.OverrideFlag(); !got { - t.Fatalf("got OverrideFlag = false, want = true") + t.Errorf("got OverrideFlag = false, want = true") } // Test updating the Target Address. addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") na.SetTargetAddress(addr2) if got := na.TargetAddress(); got != addr2 { - t.Fatalf("got TargetAddress = %s, want %s", got, addr2) + t.Errorf("got TargetAddress = %s, want %s", got, addr2) } // Make sure the address got updated in the backing buffer. if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { - t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2) + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) } // Test updating the Router Flag. na.SetRouterFlag(false) if got := na.RouterFlag(); got { - t.Fatalf("got RouterFlag = true, want = false") + t.Errorf("got RouterFlag = true, want = false") } // Test updating the Solicited Flag. na.SetSolicitedFlag(true) if got := na.SolicitedFlag(); !got { - t.Fatalf("got SolicitedFlag = false, want = true") + t.Errorf("got SolicitedFlag = false, want = true") } // Test updating the Override Flag. na.SetOverrideFlag(false) if got := na.OverrideFlag(); got { - t.Fatalf("got OverrideFlag = true, want = false") + t.Errorf("got OverrideFlag = true, want = false") } // Make sure flags got updated in the backing buffer. if got := b[ndpNAFlagsOffset]; got != 64 { - t.Fatalf("got flags byte = %d, want = 64") + t.Errorf("got flags byte = %d, want = 64", got) } } @@ -128,27 +133,181 @@ func TestNDPRouterAdvert(t *testing.T) { ra := NDPRouterAdvert(b) if got := ra.CurrHopLimit(); got != 64 { - t.Fatalf("got ra.CurrHopLimit = %d, want = 64", got) + t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) } if got := ra.ManagedAddrConfFlag(); !got { - t.Fatalf("got ManagedAddrConfFlag = false, want = true") + t.Errorf("got ManagedAddrConfFlag = false, want = true") } if got := ra.OtherConfFlag(); got { - t.Fatalf("got OtherConfFlag = true, want = false") + t.Errorf("got OtherConfFlag = true, want = false") } if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Fatalf("got ra.RouterLifetime = %d, want = %d", got, want) + t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) } if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Fatalf("got ra.ReachableTime = %d, want = %d", got, want) + t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) } if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Fatalf("got ra.RetransTimer = %d, want = %d", got, want) + t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + } +} + +// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the +// Ethernet address from an NDPSourceLinkLayerAddressOption. +func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { + tests := []struct { + name string + buf []byte + expected tcpip.LinkAddress + }{ + { + "ValidMAC", + []byte{1, 2, 3, 4, 5, 6}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "SLLBodyTooShort", + []byte{1, 2, 3, 4, 5}, + tcpip.LinkAddress([]byte(nil)), + }, + { + "SLLBodyLargerThanNeeded", + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sll := NDPSourceLinkLayerAddressOption(test.buf) + if got := sll.EthernetAddress(); got != test.expected { + t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected) + } + }) + } +} + +// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a +// NDPSourceLinkLayerAddressOption. +func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) { + tests := []struct { + name string + buf []byte + expectedBuf []byte + addr tcpip.LinkAddress + }{ + { + "Ethernet", + make([]byte, 8), + []byte{1, 1, 1, 2, 3, 4, 5, 6}, + "\x01\x02\x03\x04\x05\x06", + }, + { + "Padding", + []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + "\x01\x02\x03\x04\x05\x06\x07\x08", + }, + { + "Empty", + nil, + nil, + "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + serializer := NDPOptionsSerializer{ + NDPSourceLinkLayerAddressOption(test.addr), + } + if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { + t.Fatalf("got Length = %d, want = %d", got, want) + } + opts.Serialize(serializer) + if !bytes.Equal(test.buf, test.expectedBuf) { + t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + if len(test.expectedBuf) > 0 { + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { + t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) + } + sll := next.(NDPSourceLinkLayerAddressOption) + if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) { + t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + + if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { + t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) + } + } + + // Iterator should not return anything else. + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + +// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the +// Ethernet address from an NDPTargetLinkLayerAddressOption. +func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { + tests := []struct { + name string + buf []byte + expected tcpip.LinkAddress + }{ + { + "ValidMAC", + []byte{1, 2, 3, 4, 5, 6}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "TLLBodyTooShort", + []byte{1, 2, 3, 4, 5}, + tcpip.LinkAddress([]byte(nil)), + }, + { + "TLLBodyLargerThanNeeded", + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tll := NDPTargetLinkLayerAddressOption(test.buf) + if got := tll.EthernetAddress(); got != test.expected { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) + } + }) } } @@ -175,8 +334,8 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { }, { "Empty", - []byte{}, - []byte{}, + nil, + nil, "", }, } @@ -194,6 +353,44 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { if !bytes.Equal(test.buf, test.expectedBuf) { t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + if len(test.expectedBuf) > 0 { + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + tll := next.(NDPTargetLinkLayerAddressOption) + if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { + t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + + if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) + } + } + + // Iterator should not return anything else. + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } }) } } @@ -232,39 +429,1093 @@ func TestNDPPrefixInformationOption(t *testing.T) { t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) } - // First two bytes are the Type and Length fields, which are not part of - // the option body. - pi := NDPPrefixInformation(targetBuf[2:]) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + pi := next.(NDPPrefixInformation) if got := pi.Type(); got != 3 { - t.Fatalf("got Type = %d, want = 3", got) + t.Errorf("got Type = %d, want = 3", got) } if got := pi.Length(); got != 30 { - t.Fatalf("got Length = %d, want = 30", got) + t.Errorf("got Length = %d, want = 30", got) } if got := pi.PrefixLength(); got != 43 { - t.Fatalf("got PrefixLength = %d, want = 43", got) + t.Errorf("got PrefixLength = %d, want = 43", got) } if pi.OnLinkFlag() { - t.Fatalf("got OnLinkFlag = true, want = false") + t.Error("got OnLinkFlag = true, want = false") } if !pi.AutonomousAddressConfigurationFlag() { - t.Fatalf("got AutonomousAddressConfigurationFlag = false, want = true") + t.Error("got AutonomousAddressConfigurationFlag = false, want = true") } if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { - t.Fatalf("got ValidLifetime = %d, want = %d", got, want) + t.Errorf("got ValidLifetime = %d, want = %d", got, want) } if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { - t.Fatalf("got PreferredLifetime = %d, want = %d", got, want) + t.Errorf("got PreferredLifetime = %d, want = %d", got, want) } if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { - t.Fatalf("got Prefix = %s, want = %s", got, want) + t.Errorf("got Prefix = %s, want = %s", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 25, 3, 0, 0, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPRecursiveDNSServer(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Type(); got != 25 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16909320*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + want := []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + } + addrs, err := opt.Addresses() + if err != nil { + t.Errorf("opt.Addresses() = %s", err) + } + if diff := cmp.Diff(addrs, want); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + addrs []tcpip.Address + }{ + { + "Valid1Addr", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + }, + }, + { + "Valid2Addr", + []byte{ + 25, 5, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + }, + }, + { + "Valid3Addr", + []byte{ + 25, 7, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Iterator should get our option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + addrs, err := opt.Addresses() + if err != nil { + t.Errorf("opt.Addresses() = %s", err) + } + if diff := cmp.Diff(addrs, test.addrs); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + +// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList. +func TestNDPDNSSearchListOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + domainNames []string + err error + }{ + { + name: "Valid1Label", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: []string{ + "abc", + }, + err: nil, + }, + { + name: "Valid2Label", + buf: []byte{ + 0, 0, + 0, 0, 0, 5, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 0, + 0, 0, 0, 0, 0, 0, + }, + lifetime: 5 * time.Second, + domainNames: []string{ + "abc.abcd", + }, + err: nil, + }, + { + name: "Valid3Label", + buf: []byte{ + 0, 0, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + }, + lifetime: 16777216 * time.Second, + domainNames: []string{ + "abc.abcd.e", + }, + err: nil, + }, + { + name: "Valid2Domains", + buf: []byte{ + 0, 0, + 1, 2, 3, 4, + 3, 'a', 'b', 'c', + 0, + 2, 'd', 'e', + 3, 'x', 'y', 'z', + 0, + 0, 0, 0, + }, + lifetime: 16909060 * time.Second, + domainNames: []string{ + "abc", + "de.xyz", + }, + err: nil, + }, + { + name: "Valid3DomainsMixedCase", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 'B', 'c', + 0, + 2, 'd', 'E', + 3, 'X', 'y', 'z', + 0, + 1, 'J', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abc", + "de.xyz", + "j", + }, + err: nil, + }, + { + name: "ValidDomainAfterNULL", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 'B', 'c', + 0, 0, 0, 0, + 2, 'd', 'E', + 3, 'X', 'y', 'z', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abc", + "de.xyz", + }, + err: nil, + }, + { + name: "Valid0Domains", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 0, + 0, 0, 0, 0, 0, 0, 0, + }, + lifetime: 0, + domainNames: nil, + err: nil, + }, + { + name: "NoTrailingNull", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g', + }, + lifetime: 0, + domainNames: nil, + err: io.ErrUnexpectedEOF, + }, + { + name: "IncorrectLength", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g', + }, + lifetime: 0, + domainNames: nil, + err: io.ErrUnexpectedEOF, + }, + { + name: "IncorrectLengthWithNULL", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 7, 'a', 'b', 'c', 'd', 'e', 'f', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "LabelOfLength63", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk", + }, + err: nil, + }, + { + name: "LabelOfLength64", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "DomainNameOfLength255", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij", + }, + err: nil, + }, + { + name: "DomainNameOfLength256", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "StartingDigitForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, '9', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "StartingHyphenForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, '-', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "EndingHyphenForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', '-', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "EndingDigitForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', '9', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: []string{ + "ab9", + }, + err: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opt := NDPDNSSearchList(test.buf) + + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + domainNames, err := opt.DomainNames() + if !errors.Is(err, test.err) { + t.Errorf("opt.DomainNames() = %s", err) + } + if diff := cmp.Diff(domainNames, test.domainNames); diff != "" { + t.Errorf("mismatched domain names (-want +got):\n%s", diff) + } + }) + } +} + +func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { + for r := rune(0); r <= 255; r++ { + t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) { + buf := []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 0 /* will be replaced */, 'c', + 0, + 0, 0, 0, + } + buf[8] = uint8(r) + opt := NDPDNSSearchList(buf) + + // As per RFC 1035 section 2.3.1, the label must only include ASCII + // letters, digits and hyphens (a-z, A-Z, 0-9, -). + var expectedErr error + re := regexp.MustCompile(`[a-zA-Z0-9-]`) + if !re.Match([]byte{byte(r)}) { + expectedErr = ErrNDPOptMalformedBody + } + + if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) { + t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody) + } + }) + } +} + +func TestNDPDNSSearchListOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 31, 3, 0, 0, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPDNSSearchList(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPDNSSearchListOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType) + } + + opt, ok := next.(NDPDNSSearchList) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next) + } + if got := opt.Type(); got != 31 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16777216*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + domainNames, err := opt.DomainNames() + if err != nil { + t.Errorf("opt.DomainNames() = %s", err) + } + if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" { + t.Errorf("domain names mismatch (-want +got):\n%s", diff) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions +// the iterator was returned for is malformed. +func TestNDPOptionsIterCheck(t *testing.T) { + tests := []struct { + name string + buf []byte + expectedErr error + }{ + { + name: "ZeroLengthField", + buf: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + expectedErr: ErrNDPOptMalformedHeader, + }, + { + name: "ValidSourceLinkLayerAddressOption", + buf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, + expectedErr: nil, + }, + { + name: "TooSmallSourceLinkLayerAddressOption", + buf: []byte{1, 1, 1, 2, 3, 4, 5}, + expectedErr: io.ErrUnexpectedEOF, + }, + { + name: "ValidTargetLinkLayerAddressOption", + buf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, + expectedErr: nil, + }, + { + name: "TooSmallTargetLinkLayerAddressOption", + buf: []byte{2, 1, 1, 2, 3, 4, 5}, + expectedErr: io.ErrUnexpectedEOF, + }, + { + name: "ValidPrefixInformation", + buf: []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + expectedErr: nil, + }, + { + name: "TooSmallPrefixInformation", + buf: []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, + }, + expectedErr: io.ErrUnexpectedEOF, + }, + { + name: "InvalidPrefixInformationLength", + buf: []byte{ + 3, 3, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + }, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", + buf: []byte{ + // Source Link-Layer Address. + 1, 1, 1, 2, 3, 4, 5, 6, + + // Target Link-Layer Address. + 2, 1, 7, 8, 9, 10, 11, 12, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + expectedErr: nil, + }, + { + name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + buf: []byte{ + // Source Link-Layer Address. + 1, 1, 1, 2, 3, 4, 5, 6, + + // Target Link-Layer Address. + 2, 1, 7, 8, 9, 10, 11, 12, + + // 255 is an unrecognized type. If 255 ends up + // being the type for some recognized type, + // update 255 to some other unrecognized value. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + expectedErr: nil, + }, + { + name: "InvalidRecursiveDNSServerCutsOffAddress", + buf: []byte{ + 25, 4, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 1, 2, 3, 4, 5, 6, 7, + }, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "InvalidRecursiveDNSServerInvalidLengthField", + buf: []byte{ + 25, 2, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, + }, + expectedErr: io.ErrUnexpectedEOF, + }, + { + name: "RecursiveDNSServerTooSmall", + buf: []byte{ + 25, 1, 0, 0, + 0, 0, 0, + }, + expectedErr: io.ErrUnexpectedEOF, + }, + { + name: "RecursiveDNSServerMulticast", + buf: []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "RecursiveDNSServerUnspecified", + buf: []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "DNSSearchListLargeCompliantRFC1035", + buf: []byte{ + 31, 33, 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', + 0, + }, + expectedErr: nil, + }, + { + name: "DNSSearchListNonCompliantRFC1035", + buf: []byte{ + 31, 33, 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "DNSSearchListValidSmall", + buf: []byte{ + 31, 2, 0, 0, + 0, 0, 0, 0, + 6, 'a', 'b', 'c', 'd', 'e', 'f', + 0, + }, + expectedErr: nil, + }, + { + name: "DNSSearchListTooSmall", + buf: []byte{ + 31, 1, 0, 0, + 0, 0, 0, + }, + expectedErr: io.ErrUnexpectedEOF, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + + if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) { + t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr) + } + + // test.buf may be malformed but we chose not to check + // the iterator so it must return true. + if _, err := opts.Iter(false); err != nil { + t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) + } + }) + } +} + +// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, +// this test does not actually check any of the option's getters, it simply +// checks the option Type and Body. We have other tests that tests the option +// field gettings given an option body and don't need to duplicate those tests +// here. +func TestNDPOptionsIter(t *testing.T) { + buf := []byte{ + // Source Link-Layer Address. + 1, 1, 1, 2, 3, 4, 5, 6, + + // Target Link-Layer Address. + 2, 1, 7, 8, 9, 10, 11, 12, + + // 255 is an unrecognized type. If 255 ends up being the type + // for some recognized type, update 255 to some other + // unrecognized value. Note, this option should be skipped when + // iterating. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + opts := NDPOptions(buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Test the first (Source Link-Layer) option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) + } + + // Test the next (Target Link-Layer) option. + next, done, err = it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + + // Test the next (Prefix Information) option. + // Note, the unrecognized option should be skipped. + next, done, err = it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) } } diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go new file mode 100644 index 000000000..6fe9a336b --- /dev/null +++ b/pkg/tcpip/header/ndpoptionidentifier_string.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT. + +package header + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[NDPSourceLinkLayerAddressOptionType-1] + _ = x[NDPTargetLinkLayerAddressOptionType-2] + _ = x[NDPPrefixInformationType-3] + _ = x[NDPRecursiveDNSServerOptionType-25] +} + +const ( + _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType" + _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType" +) + +var ( + _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} +) + +func (i NDPOptionIdentifier) String() string { + switch { + case 1 <= i && i <= 3: + i -= 1 + return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]] + case i == 25: + return _NDPOptionIdentifier_name_1 + default: + return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 82cfe785c..4c6f808e5 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -66,6 +66,14 @@ const ( TCPOptionSACK = 5 ) +// Option Lengths. +const ( + TCPOptionMSSLength = 4 + TCPOptionTSLength = 10 + TCPOptionWSLength = 3 + TCPOptionSackPermittedLength = 2 +) + // TCPFields contains the fields of a TCP packet. It is used to describe the // fields of a packet that needs to be encoded. type TCPFields struct { @@ -81,7 +89,8 @@ type TCPFields struct { // AckNum is the "acknowledgement number" field of a TCP packet. AckNum uint32 - // DataOffset is the "data offset" field of a TCP packet. + // DataOffset is the "data offset" field of a TCP packet. It is the length of + // the TCP header in bytes. DataOffset uint8 // Flags is the "flags" field of a TCP packet. @@ -213,7 +222,8 @@ func (b TCP) AckNumber() uint32 { return binary.BigEndian.Uint32(b[TCPAckNumOffset:]) } -// DataOffset returns the "data offset" field of the tcp header. +// DataOffset returns the "data offset" field of the tcp header. The return +// value is the length of the TCP header in bytes. func (b TCP) DataOffset() uint8 { return (b[TCPDataOffset] >> 4) * 4 } @@ -238,6 +248,11 @@ func (b TCP) Checksum() uint16 { return binary.BigEndian.Uint16(b[TCPChecksumOffset:]) } +// UrgentPointer returns the "urgent pointer" field of the tcp header. +func (b TCP) UrgentPointer() uint16 { + return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:]) +} + // SetSourcePort sets the "source port" field of the tcp header. func (b TCP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port) @@ -253,6 +268,37 @@ func (b TCP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum) } +// SetDataOffset sets the data offset field of the tcp header. headerLen should +// be the length of the TCP header in bytes. +func (b TCP) SetDataOffset(headerLen uint8) { + b[TCPDataOffset] = (headerLen / 4) << 4 +} + +// SetSequenceNumber sets the sequence number field of the tcp header. +func (b TCP) SetSequenceNumber(seqNum uint32) { + binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum) +} + +// SetAckNumber sets the ack number field of the tcp header. +func (b TCP) SetAckNumber(ackNum uint32) { + binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum) +} + +// SetFlags sets the flags field of the tcp header. +func (b TCP) SetFlags(flags uint8) { + b[TCPFlagsOffset] = flags +} + +// SetWindowSize sets the window size field of the tcp header. +func (b TCP) SetWindowSize(rcvwnd uint16) { + binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) +} + +// SetUrgentPoiner sets the window size field of the tcp header. +func (b TCP) SetUrgentPoiner(urgentPointer uint16) { + binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer) +} + // CalculateChecksum calculates the checksum of the tcp segment. // partialChecksum is the checksum of the network-layer pseudo-header // and the checksum of the segment data. @@ -456,14 +502,11 @@ func ParseTCPOptions(b []byte) TCPOptions { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeMSSOption(mss uint32, b []byte) int { - // mssOptionSize is the number of bytes in a valid MSS option. - const mssOptionSize = 4 - - if len(b) < mssOptionSize { + if len(b) < TCPOptionMSSLength { return 0 } - b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss) - return mssOptionSize + b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss) + return TCPOptionMSSLength } // EncodeWSOption encodes the WS TCP option with the WS value in the @@ -471,10 +514,10 @@ func EncodeMSSOption(mss uint32, b []byte) int { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeWSOption(ws int, b []byte) int { - if len(b) < 3 { + if len(b) < TCPOptionWSLength { return 0 } - b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws) + b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws) return int(b[1]) } @@ -483,10 +526,10 @@ func EncodeWSOption(ws int, b []byte) int { // just returns without encoding anything. It returns the number of bytes // written to the provided buffer. func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { - if len(b) < 10 { + if len(b) < TCPOptionTSLength { return 0 } - b[0], b[1] = TCPOptionTS, 10 + b[0], b[1] = TCPOptionTS, TCPOptionTSLength binary.BigEndian.PutUint32(b[2:], tsVal) binary.BigEndian.PutUint32(b[6:], tsEcr) return int(b[1]) @@ -497,11 +540,11 @@ func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { // encoding anything. It returns the number of bytes written to the provided // buffer. func EncodeSACKPermittedOption(b []byte) int { - if len(b) < 2 { + if len(b) < TCPOptionSackPermittedLength { return 0 } - b[0], b[1] = TCPOptionSACKPermitted, 2 + b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength return int(b[1]) } @@ -556,3 +599,23 @@ func AddTCPOptionPadding(options []byte, offset int) int { } return paddingToAdd } + +// Acceptable checks if a segment that starts at segSeq and has length segLen is +// "acceptable" for arriving in a receive window that starts at rcvNxt and ends +// before rcvAcc, according to the table on page 26 and 69 of RFC 793. +func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { + if rcvNxt == rcvAcc { + return segLen == 0 && segSeq == rcvNxt + } + if segLen == 0 { + // rcvWnd is incremented by 1 because that is Linux's behavior despite the + // RFC. + return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) + } + // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming + // the payload, so we'll accept any payload that overlaps the receieve window. + // segSeq < rcvAcc is more correct according to RFC, however, Linux does it + // differently, it uses segSeq <= rcvAcc, we'd want to keep the same behavior + // as Linux. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc) +} diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 74412c894..9339d637f 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -99,6 +99,11 @@ func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } +// SetLength sets the "length" field of the udp header. +func (b UDP) SetLength(length uint16) { + binary.BigEndian.PutUint16(b[udpLength:], length) +} + // CalculateChecksum calculates the checksum of the udp packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD deleted file mode 100644 index cc5f531e2..000000000 --- a/pkg/tcpip/iptables/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "iptables", - srcs = [ - "iptables.go", - "targets.go", - "types.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", - visibility = ["//visibility:public"], - deps = ["//pkg/tcpip/buffer"], -) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go deleted file mode 100644 index 68c68d4aa..000000000 --- a/pkg/tcpip/iptables/iptables.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2019 The gVisor authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package iptables supports packet filtering and manipulation via the iptables -// tool. -package iptables - -const ( - tablenameNat = "nat" - tablenameMangle = "mangle" -) - -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. -const ( - chainNamePrerouting = "PREROUTING" - chainNameInput = "INPUT" - chainNameForward = "FORWARD" - chainNameOutput = "OUTPUT" - chainNamePostrouting = "POSTROUTING" -) - -// DefaultTables returns a default set of tables. Each chain is set to accept -// all packets. -func DefaultTables() IPTables { - return IPTables{ - Tables: map[string]Table{ - tablenameNat: Table{ - BuiltinChains: map[Hook]Chain{ - Prerouting: unconditionalAcceptChain(chainNamePrerouting), - Input: unconditionalAcceptChain(chainNameInput), - Output: unconditionalAcceptChain(chainNameOutput), - Postrouting: unconditionalAcceptChain(chainNamePostrouting), - }, - DefaultTargets: map[Hook]Target{ - Prerouting: UnconditionalAcceptTarget{}, - Input: UnconditionalAcceptTarget{}, - Output: UnconditionalAcceptTarget{}, - Postrouting: UnconditionalAcceptTarget{}, - }, - UserChains: map[string]Chain{}, - }, - tablenameMangle: Table{ - BuiltinChains: map[Hook]Chain{ - Prerouting: unconditionalAcceptChain(chainNamePrerouting), - Output: unconditionalAcceptChain(chainNameOutput), - }, - DefaultTargets: map[Hook]Target{ - Prerouting: UnconditionalAcceptTarget{}, - Output: UnconditionalAcceptTarget{}, - }, - UserChains: map[string]Chain{}, - }, - }, - Priorities: map[Hook][]string{ - Prerouting: []string{tablenameMangle, tablenameNat}, - Output: []string{tablenameMangle, tablenameNat}, - }, - } -} - -func unconditionalAcceptChain(name string) Chain { - return Chain{ - Name: name, - Rules: []Rule{ - Rule{ - Target: UnconditionalAcceptTarget{}, - }, - }, - } -} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go deleted file mode 100644 index 42a79ef9f..000000000 --- a/pkg/tcpip/iptables/types.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2019 The gVisor authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// A Hook specifies one of the hooks built into the network stack. -// -// Userspace app Userspace app -// ^ | -// | v -// [Input] [Output] -// ^ | -// | v -// | routing -// | | -// | v -// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> -type Hook uint - -// These values correspond to values in include/uapi/linux/netfilter.h. -const ( - // Prerouting happens before a packet is routed to applications or to - // be forwarded. - Prerouting Hook = iota - - // Input happens before a packet reaches an application. - Input - - // Forward happens once it's decided that a packet should be forwarded - // to another host. - Forward - - // Output happens after a packet is written by an application to be - // sent out. - Output - - // Postrouting happens just before a packet goes out on the wire. - Postrouting - - // The total number of hooks. - NumHooks -) - -// A Verdict is returned by a rule's target to indicate how traversal of rules -// should (or should not) continue. -type Verdict int - -const ( - // Accept indicates the packet should continue traversing netstack as - // normal. - Accept Verdict = iota - - // Drop inicates the packet should be dropped, stopping traversing - // netstack. - Drop - - // Stolen indicates the packet was co-opted by the target and should - // stop traversing netstack. - Stolen - - // Queue indicates the packet should be queued for userspace processing. - Queue - - // Repeat indicates the packet should re-traverse the chains for the - // current hook. - Repeat - - // None indicates no verdict was reached. - None - - // Jump indicates a jump to another chain. - Jump - - // Continue indicates that traversal should continue at the next rule. - Continue - - // Return indicates that traversal should return to the calling chain. - Return -) - -// IPTables holds all the tables for a netstack. -type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table - - // Priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string -} - -// A Table defines a set of chains and hooks into the network stack. The -// currently supported tables are: -// * nat -// * mangle -type Table struct { - // BuiltinChains holds the un-deletable chains built into netstack. If - // a hook isn't present in the map, this table doesn't utilize that - // hook. - BuiltinChains map[Hook]Chain - - // DefaultTargets holds a target for each hook that will be executed if - // chain traversal doesn't yield a verdict. - DefaultTargets map[Hook]Target - - // UserChains holds user-defined chains for the keyed by name. Users - // can give their chains arbitrary names. - UserChains map[string]Chain - - // Chains maps names to chains for both builtin and user-defined chains. - // Its entries point to Chains already either in BuiltinChains or - // UserChains, and its purpose is to make looking up tables by name - // fast. - Chains map[string]*Chain - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} -} - -// ValidHooks returns a bitmap of the builtin hooks for the given table. -func (table *Table) ValidHooks() uint32 { - hooks := uint32(0) - for hook, _ := range table.BuiltinChains { - hooks |= 1 << hook - } - return hooks -} - -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - -// A Chain defines a list of rules for packet processing. When a packet -// traverses a chain, it is checked against each rule until either a rule -// returns a verdict or the chain ends. -// -// By convention, builtin chains end with a rule that matches everything and -// returns either Accept or Drop. User-defined chains end with Return. These -// aren't strictly necessary here, but the iptables tool writes tables this way. -type Chain struct { - // Name is the chain name. - Name string - - // Rules is the list of rules to traverse. - Rules []Rule -} - -// A Rule is a packet processing rule. It consists of two pieces. First it -// contains zero or more matchers, each of which is a specification of which -// packets this rule applies to. If there are no matchers in the rule, it -// applies to any packet. -type Rule struct { - // Matchers is the list of matchers for this rule. - Matchers []Matcher - - // Target is the action to invoke if all the matchers match the packet. - Target Target -} - -// A Matcher is the interface for matching packets. -type Matcher interface { - // Match returns whether the packet matches and whether the packet - // should be "hotdropped", i.e. dropped immediately. This is usually - // used for suspicious packets. - Match(hook Hook, packet buffer.VectorisedView, interfaceName string) (matches bool, hotdrop bool) -} - -// A Target is the interface for taking an action for a packet. -type Target interface { - // Action takes an action on the packet and returns a verdict on how - // traversal should (or should not) continue. If the return value is - // Jump, it also returns the name of the chain to jump to. - Action(packet buffer.VectorisedView) (Verdict, string) -} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 97a794986..39ca774ef 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -1,15 +1,16 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "channel", srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 14f197a77..c95aef63c 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -18,61 +18,177 @@ package channel import ( + "context" + + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Header buffer.View - Payload buffer.View - Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO + Pkt *stack.PacketBuffer + Proto tcpip.NetworkProtocolNumber + GSO *stack.GSO + Route stack.Route +} + +// Notification is the interface for receiving notification from the packet +// queue. +type Notification interface { + // WriteNotify will be called when a write happens to the queue. + WriteNotify() +} + +// NotificationHandle is an opaque handle to the registered notification target. +// It can be used to unregister the notification when no longer interested. +// +// +stateify savable +type NotificationHandle struct { + n Notification +} + +type queue struct { + // c is the outbound packet channel. + c chan PacketInfo + // mu protects fields below. + mu sync.RWMutex + notify []*NotificationHandle +} + +func (q *queue) Close() { + close(q.c) +} + +func (q *queue) Read() (PacketInfo, bool) { + select { + case p := <-q.c: + return p, true + default: + return PacketInfo{}, false + } +} + +func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { + select { + case pkt := <-q.c: + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + +func (q *queue) Write(p PacketInfo) bool { + wrote := false + select { + case q.c <- p: + wrote = true + default: + } + q.mu.Lock() + notify := q.notify + q.mu.Unlock() + + if wrote { + // Send notification outside of lock. + for _, h := range notify { + h.n.WriteNotify() + } + } + return wrote +} + +func (q *queue) Num() int { + return len(q.c) +} + +func (q *queue) AddNotify(notify Notification) *NotificationHandle { + q.mu.Lock() + defer q.mu.Unlock() + h := &NotificationHandle{n: notify} + q.notify = append(q.notify, h) + return h +} + +func (q *queue) RemoveNotify(handle *NotificationHandle) { + q.mu.Lock() + defer q.mu.Unlock() + // Make a copy, since we reads the array outside of lock when notifying. + notify := make([]*NotificationHandle, 0, len(q.notify)) + for _, h := range q.notify { + if h != handle { + notify = append(notify, h) + } + } + q.notify = notify } // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { - dispatcher stack.NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - GSO bool + dispatcher stack.NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + LinkEPCapabilities stack.LinkEndpointCapabilities - // C is where outbound packets are queued. - C chan PacketInfo + // Outbound packet queue. + q *queue } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - C: make(chan PacketInfo, size), + q: &queue{ + c: make(chan PacketInfo, size), + }, mtu: mtu, linkAddr: linkAddr, } } +// Close closes e. Further packet injections will panic. Reads continue to +// succeed until all packets are read. +func (e *Endpoint) Close() { + e.q.Close() +} + +// Read does non-blocking read one packet from the outbound packet queue. +func (e *Endpoint) Read() (PacketInfo, bool) { + return e.q.Read() +} + +// ReadContext does blocking read for one packet from the outbound packet queue. +// It can be cancelled by ctx, and in this case, it returns false. +func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { + return e.q.ReadContext(ctx) +} + // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { - select { - case <-e.C: - c++ - default: + if _, ok := e.Read(); !ok { return c } + c++ } } -// Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.InjectLinkAddr(protocol, "", vv) +// NumQueued returns the number of packet queued for outbound. +func (e *Endpoint) NumQueued() int { + return e.q.Num() +} + +// InjectInbound injects an inbound packet. +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil), nil /* linkHeader */) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -94,11 +210,7 @@ func (e *Endpoint) MTU() uint32 { // Capabilities implements stack.LinkEndpoint.Capabilities. func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - caps := stack.LinkEndpointCapabilities(0) - if e.GSO { - caps |= stack.CapabilityHardwareGSO - } - return caps + return e.LinkEPCapabilities } // GSOMaxSize returns the maximum GSO packet size. @@ -118,65 +230,81 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + // Clone r then release its resource so we only get the relevant fields from + // stack.Route without holding a reference to a NIC's endpoint. + route := r.Clone() + route.Release() p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: payload.ToView(), - GSO: gso, + Pkt: pkt, + Proto: protocol, + GSO: gso, + Route: route, } - select { - case e.C <- p: - default: - } + e.q.Write(p) return nil } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - payloadView := payload.ToView() +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Clone r then release its resource so we only get the relevant fields from + // stack.Route without holding a reference to a NIC's endpoint. + route := r.Clone() + route.Release() n := 0 -packetLoop: - for i := range hdrs { - hdr := &hdrs[i].Hdr - off := hdrs[i].Off - size := hdrs[i].Size + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: buffer.NewViewFromBytes(payloadView[off : off+size]), - GSO: gso, + Pkt: pkt, + Proto: protocol, + GSO: gso, + Route: route, } - select { - case e.C <- p: - n++ - default: - break packetLoop + if !e.q.Write(p) { + break } + n++ } return n, nil } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Header: packet.ToView(), - Proto: 0, - Payload: buffer.View{}, - GSO: nil, + Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }), + Proto: 0, + GSO: nil, } - select { - case e.C <- p: - default: - } + e.q.Write(p) return nil } // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} + +// AddNotify adds a notification target for receiving event about outgoing +// packets. +func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { + return e.q.AddNotify(notify) +} + +// RemoveNotify removes handle from the list of notification targets. +func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { + e.q.RemoveNotify(handle) +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 8fa9e3984..10072eac1 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -13,11 +12,11 @@ go_library( "mmap_unsafe.go", "packet_dispatchers.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ + "//pkg/binary", + "//pkg/iovec", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -31,12 +30,13 @@ go_test( name = "fdbased_test", size = "small", srcs = ["endpoint_test.go"], - embed = [":fdbased"], + library = ":fdbased", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index ae4858529..975309fc8 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,10 +41,12 @@ package fdbased import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/iovec" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -91,7 +93,7 @@ func (p PacketDispatchMode) String() string { case PacketMMap: return "PacketMMap" default: - return fmt.Sprintf("unknown packet dispatch mode %v", p) + return fmt.Sprintf("unknown packet dispatch mode '%d'", p) } } @@ -384,37 +386,46 @@ const ( _VIRTIO_NET_HDR_GSO_TCPV6 = 4 ) -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) } +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.hdrSize > 0 { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } + + var builder iovec.Builder + fd := e.fds[pkt.Hash%uint32(len(e.fds))] if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} - vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) if gso != nil { - vnetHdr.hdrLen = uint16(hdr.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen vnetHdr.csumOffset = gso.CsumOffset } - if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS { + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { switch gso.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 @@ -427,136 +438,129 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } } - return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView()) + vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + builder.Add(vnetHdrBuf) } - if payload.Size() == 0 { - return rawfile.NonBlockingWrite(e.fds[0], hdr.View()) + for _, v := range pkt.Views() { + builder.Add(v) } - - return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil) + return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } -// WritePackets writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - var ethHdrBuf []byte - // hdr + data - iovLen := 2 - if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, +func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { + // Send a batch of packets through batchFD. + mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) + for _, pkt := range batch { + if e.hdrSize > 0 { + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) } - // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ - } - - n := len(hdrs) - - views := payload.Views() - /* - * Each bondary in views can add one more iovec. - * - * payload | | | | - * ----------------------------- - * packets | | | | | | | - * ----------------------------- - * iovecs | | | | | | | | | - */ - iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) - mmsgHdrs := make([]rawfile.MMsgHdr, n) - - iovecIdx := 0 - viewIdx := 0 - viewOff := 0 - off := 0 - nextOff := 0 - for i := range hdrs { - prevIovecIdx := iovecIdx - mmsgHdr := &mmsgHdrs[i] - mmsgHdr.Msg.Iov = &iovec[iovecIdx] - packetSize := hdrs[i].Size - hdr := &hdrs[i].Hdr - - off = hdrs[i].Off - if off != nextOff { - // We stop in a different point last time. - size := packetSize - viewIdx = 0 - viewOff = 0 - for size > 0 { - if size >= len(views[viewIdx]) { - viewIdx++ - viewOff = 0 - size -= len(views[viewIdx]) - } else { - viewOff = size - size = 0 + var vnetHdrBuf []byte + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + vnetHdr := virtioNetHdr{} + if pkt.GSOOptions != nil { + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) + if pkt.GSOOptions.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset + } + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + } + vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - } - nextOff = off + packetSize - - if ethHdrBuf != nil { - v := &iovec[iovecIdx] - v.Base = ðHdrBuf[0] - v.Len = uint64(len(ethHdrBuf)) - iovecIdx++ + vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) } - v := &iovec[iovecIdx] - hdrView := hdr.View() - v.Base = &hdrView[0] - v.Len = uint64(len(hdrView)) - iovecIdx++ - - for packetSize > 0 { - vec := &iovec[iovecIdx] - iovecIdx++ - - v := views[viewIdx] - vec.Base = &v[viewOff] - s := len(v) - viewOff - if s <= packetSize { - viewIdx++ - viewOff = 0 - } else { - s = packetSize - viewOff += s - } - vec.Len = uint64(s) - packetSize -= s + var builder iovec.Builder + builder.Add(vnetHdrBuf) + for _, v := range pkt.Views() { + builder.Add(v) } + iovecs := builder.Build() - mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + var mmsgHdr rawfile.MMsgHdr + mmsgHdr.Msg.Iov = &iovecs[0] + mmsgHdr.Msg.Iovlen = uint64(len(iovecs)) + mmsgHdrs = append(mmsgHdrs, mmsgHdr) } packets := 0 - for packets < n { - sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs) + for len(mmsgHdrs) > 0 { + sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) if err != nil { return packets, err } packets += sent mmsgHdrs = mmsgHdrs[sent:] } + return packets, nil } +// WritePackets writes outbound packets to the underlying file descriptors. If +// one is not currently writable, the packet is dropped. +// +// Being a batch API, each packet in pkts should have the following +// fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Preallocate to avoid repeated reallocation as we append to batch. + // batchSz is 47 because when SWGSO is in use then a single 65KB TCP + // segment can get split into 46 segments of 1420 bytes and a single 216 + // byte segment. + const batchSz = 47 + batch := make([]*stack.PacketBuffer, 0, batchSz) + batchFD := -1 + sentPackets := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if len(batch) == 0 { + batchFD = e.fds[pkt.Hash%uint32(len(e.fds))] + } + pktFD := e.fds[pkt.Hash%uint32(len(e.fds))] + if sendNow := pktFD != batchFD; !sendNow { + batch = append(batch, pkt) + continue + } + n, err := e.sendBatch(batchFD, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + batch = batch[:0] + batch = append(batch, pkt) + batchFD = pktFD + } + + if len(batch) != 0 { + n, err := e.sendBatch(batchFD, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + } + return sentPackets, nil +} + +// viewsEqual tests whether v1 and v2 refer to the same backing bytes. +func viewsEqual(vs1, vs2 []buffer.View) bool { + return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) +} + // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fds[0], packet.ToView()) +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) } // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. @@ -583,6 +587,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { @@ -598,8 +610,8 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 59378b96c..709f829c8 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,43 +44,75 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents buffer.View - linkHeader buffer.View + Raddr tcpip.LinkAddress + Proto tcpip.NetworkProtocolNumber + Contents *stack.PacketBuffer +} + +type packetContents struct { + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View + Data buffer.View +} + +func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { + t.Helper() + if diff := cmp.Diff( + want, got, + cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { + if pk == nil { + return nil + } + return &packetContents{ + LinkHeader: pk.LinkHeader().View(), + NetworkHeader: pk.NetworkHeader().View(), + TransportHeader: pk.TransportHeader().View(), + Data: pk.Data.ToView(), + } + }), + ); diff != "" { + t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) + } } type context struct { - t *testing.T - fds [2]int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} + t *testing.T + readFDs []int + writeFDs []int + ep stack.LinkEndpoint + ch chan packetInfo + done chan struct{} } func newContext(t *testing.T, opt *Options) *context { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + t.Fatalf("Socketpair failed: %v", err) + } + secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) if err != nil { t.Fatalf("Socketpair failed: %v", err) } - done := make(chan struct{}, 1) + done := make(chan struct{}, 2) opt.ClosedFunc = func(*tcpip.Error) { done <- struct{}{} } - opt.FDs = []int{fds[1]} + opt.FDs = []int{firstFDPair[1], secondFDPair[1]} ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } c := &context{ - t: t, - fds: fds, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, + t: t, + readFDs: []int{firstFDPair[0], secondFDPair[0]}, + writeFDs: opt.FDs, + ep: ep, + ch: make(chan packetInfo, 100), + done: done, } ep.Attach(c) @@ -88,13 +121,22 @@ func newContext(t *testing.T, opt *Options) *context { } func (c *context) cleanup() { - syscall.Close(c.fds[0]) + for _, fd := range c.readFDs { + syscall.Close(fd) + } <-c.done - syscall.Close(c.fds[1]) + <-c.done + for _, fd := range c.writeFDs { + syscall.Close(fd) + } } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { - c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader} +func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + c.ch <- packetInfo{remote, protocol, pkt} +} + +func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") } func TestNoEthernetProperties(t *testing.T) { @@ -137,7 +179,7 @@ func TestAddress(t *testing.T) { } } -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { +func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) { c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) defer c.cleanup() @@ -145,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { RemoteLinkAddress: raddr, } - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) + // Build payload. + payload := buffer.NewView(plen) + if _, err := rand.Read(payload); err != nil { + t.Fatalf("rand.Read(payload): %s", err) } - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) + // Build packet buffer. + const netHdrLen = 100 + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, + Data: payload.ToVectorisedView(), + }) + pkt.Hash = hash + + // Build header. + b := pkt.NetworkHeader().Push(netHdrLen) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read(b): %s", err) } - want := append(hdr.View(), payload...) + + // Write. + want := append(append(buffer.View(nil), b...), payload...) var gso *stack.GSO if gsoMaxSize != 0 { gso = &stack.GSO{ @@ -169,13 +220,14 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } - // Read from fd, then compare with what we wrote. + // Read from the corresponding FD, then compare with what we wrote. b = make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) + fd := c.readFDs[hash%uint32(len(c.readFDs))] + n, err := syscall.Read(fd, b) if err != nil { t.Fatalf("Read failed: %v", err) } @@ -236,7 +288,7 @@ func TestWritePacket(t *testing.T) { t.Run( fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), func(t *testing.T) { - testWritePacket(t, plen, eth, gso) + testWritePacket(t, plen, eth, gso, 0) }, ) } @@ -244,6 +296,27 @@ func TestWritePacket(t *testing.T) { } } +func TestHashedWritePacket(t *testing.T) { + lengths := []int{0, 100, 1000} + eths := []bool{true, false} + gsos := []uint32{0, 32768} + hashes := []uint32{0, 1} + for _, eth := range eths { + for _, plen := range lengths { + for _, gso := range gsos { + for _, hash := range hashes { + t.Run( + fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash), + func(t *testing.T) { + testWritePacket(t, plen, eth, gso, hash) + }, + ) + } + } + } + } +} + func TestPreserveSrcAddress(t *testing.T) { baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") @@ -256,16 +329,20 @@ func TestPreserveSrcAddress(t *testing.T) { LocalLinkAddress: baddr, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.VectorisedView{}, + }) + if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } // Read from the FD, then compare with what we wrote. b := make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) + n, err := syscall.Read(c.readFDs[0], b) if err != nil { t.Fatalf("Read failed: %v", err) } @@ -288,28 +365,29 @@ func TestDeliverPacket(t *testing.T) { defer c.cleanup() // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) + all := make([]byte, plen) + if _, err := rand.Read(all); err != nil { + t.Fatalf("rand.Read(all): %s", err) } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) + // Make it look like an IPv4 packet. + all[0] = 0x40 + + wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.NewViewFromBytes(all).ToVectorisedView(), + }) + if eth { + hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, Type: proto, }) - all = append(hdr, b...) + all = append(hdr, all...) } // Write packet via the file descriptor. - if _, err := syscall.Write(c.fds[0], all); err != nil { + if _, err := syscall.Write(c.readFDs[0], all); err != nil { t.Fatalf("Write failed: %v", err) } @@ -317,18 +395,15 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: b, - linkHeader: buffer.View(hdr), + Raddr: raddr, + Proto: proto, + Contents: wantPkt, } if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + want.Proto = header.IPv4ProtocolNumber + want.Raddr = "" } + checkPacketInfoEqual(t, pi, want) case <-time.After(10 * time.Second): t.Fatalf("Timed out waiting for packet") } @@ -455,3 +530,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) { } } + +// fakeNetworkDispatcher delivers packets to pkts. +type fakeNetworkDispatcher struct { + pkts []*stack.PacketBuffer +} + +func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + d.pkts = append(d.pkts, pkt) +} + +func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestDispatchPacketFormat(t *testing.T) { + for _, test := range []struct { + name string + newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) + }{ + { + name: "readVDispatcher", + newDispatcher: newReadVDispatcher, + }, + { + name: "recvMMsgDispatcher", + newDispatcher: newRecvMMsgDispatcher, + }, + } { + t.Run(test.name, func(t *testing.T) { + // Create a socket pair to send/recv. + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Fatal(err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + data := []byte{ + // Ethernet header. + 1, 2, 3, 4, 5, 60, + 1, 2, 3, 4, 5, 61, + 8, 0, + // Mock network header. + 40, 41, 42, 43, + } + err = syscall.Sendmsg(fds[1], data, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + + // Create and run dispatcher once. + sink := &fakeNetworkDispatcher{} + d, err := test.newDispatcher(fds[0], &endpoint{ + hdrSize: header.EthernetMinimumSize, + dispatcher: sink, + }) + if err != nil { + t.Fatal(err) + } + if ok, err := d.dispatch(); !ok || err != nil { + t.Fatalf("d.dispatch() = %v, %v", ok, err) + } + + // Verify packet. + if got, want := len(sink.pkts), 1; got != want { + t.Fatalf("len(sink.pkts) = %d, want %d", got, want) + } + pkt := sink.pkts[0] + if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { + t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) + } + if got, want := pkt.Data.Size(), 4; got != want { + t.Errorf("pkt.Data.Size() = %d, want %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go index 97a477b61..df14eaad1 100644 --- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go +++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go @@ -17,16 +17,7 @@ package fdbased import ( - "reflect" "unsafe" ) const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{})) - -func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) { - sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - sh.Data = uintptr(unsafe.Pointer(hdr)) - sh.Len = virtioNetHdrSize - sh.Cap = virtioNetHdrSize - return -} diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 554d45715..c475dda20 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -18,6 +18,7 @@ package fdbased import ( "encoding/binary" + "fmt" "syscall" "golang.org/x/sys/unix" @@ -25,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -169,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(pkt) + eth := header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -189,7 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } } - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}), buffer.View(eth)) + pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(pkt).ToVectorisedView(), + }) + if d.e.hdrSize > 0 { + if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { + panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) + } + } + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 12168a1dc..8c3ca86d6 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.allocateViews(BufConfig) n, err := rawfile.BlockingReadv(d.fd, d.iovecs) - if err != nil { + if n == 0 || err != nil { return false, err } if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { // isn't used and it isn't in a view. n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(n, BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,11 +143,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) - vv.TrimFront(d.e.hdrSize) - - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -166,7 +167,7 @@ type recvMMsgDispatcher struct { // iovecs is an array of array of iovec records where each iovec base // pointer and length are initialzed to the corresponding view above, - // except when GSO is neabled then the first iovec in each array of + // except when GSO is enabled then the first iovec in each array of // iovecs points to a buffer for the vnet header which is stripped // before the views are passed up the stack for further processing. iovecs [][]syscall.Iovec @@ -265,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(k, int(n), BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -292,10 +298,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) - vv.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index 23e4d1418..6bf3805b7 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -1,12 +1,11 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "loopback", srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index a3b48fa73..38aa694e4 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,36 +76,47 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - - // Because we're immediately turning around and writing the packet back to the - // rx path, we intentionally don't preserve the remote and local link - // addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, + }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if packet.Size() < header.EthernetMinimumSize { +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + // There should be an ethernet header at the beginning of vv. + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of packet. - linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize]) - packet.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), packet, buffer.View(linkHeader)) + linkHeader := header.Ethernet(hdr) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) return nil } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} + +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1bab380b0..e7493e5c5 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,18 +1,15 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "muxed", srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -21,7 +18,7 @@ go_test( name = "muxed_test", size = "small", srcs = ["injectable_test.go"], - embed = [":muxed"], + library = ":muxed", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 682b60291..56a611825 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -18,6 +18,7 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -80,33 +81,33 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute } - return endpoint.WritePackets(r, gso, hdrs, payload, protocol) + return endpoint.WritePackets(r, gso, pkts, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, gso, hdr, payload, protocol) + return endpoint.WritePacket(r, gso, protocol, pkt) } return tcpip.ErrNoRoute } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { // WriteRawPacket doesn't get a route or network address, so there's // nowhere to write this. return tcpip.ErrNoRoute @@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() { } } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unsupported operation") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 9cd300af8..3e4afcdad 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -46,12 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) { func TestInjectableEndpointDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) + pkt.TransportHeader().Push(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -65,11 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) { func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD new file mode 100644 index 000000000..2cdb23475 --- /dev/null +++ b/pkg/tcpip/link/nested/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "nested", + srcs = [ + "nested.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "nested_test", + size = "small", + srcs = [ + "nested_test.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go new file mode 100644 index 000000000..d40de54df --- /dev/null +++ b/pkg/tcpip/link/nested/nested.go @@ -0,0 +1,152 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nested provides helpers to implement the pattern of nested +// stack.LinkEndpoints. +package nested + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher +// that can be used to implement nesting safely by providing lifecycle +// concurrency guards. +// +// See the tests in this package for example usage. +type Endpoint struct { + child stack.LinkEndpoint + embedder stack.NetworkDispatcher + + // mu protects dispatcher. + mu sync.RWMutex + dispatcher stack.NetworkDispatcher +} + +var _ stack.GSOEndpoint = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.NetworkDispatcher = (*Endpoint)(nil) + +// Init initializes a nested.Endpoint that uses embedder as the dispatcher for +// child on Attach. +// +// See the tests in this package for example usage. +func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) { + e.child = child + e.embedder = embedder +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverNetworkPacket(remote, local, protocol, pkt) + } +} + +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverOutboundPacket(remote, local, protocol, pkt) + } +} + +// Attach implements stack.LinkEndpoint. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + e.dispatcher = dispatcher + e.mu.Unlock() + // If we're attaching to a valid dispatcher, pass embedder as the dispatcher + // to our child, otherwise detach the child by giving it a nil dispatcher. + var pass stack.NetworkDispatcher + if dispatcher != nil { + pass = e.embedder + } + e.child.Attach(pass) +} + +// IsAttached implements stack.LinkEndpoint. +func (e *Endpoint) IsAttached() bool { + e.mu.RLock() + isAttached := e.dispatcher != nil + e.mu.RUnlock() + return isAttached +} + +// MTU implements stack.LinkEndpoint. +func (e *Endpoint) MTU() uint32 { + return e.child.MTU() +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.child.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return e.child.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.child.LinkAddress() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + return e.child.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + return e.child.WritePackets(r, gso, pkts, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return e.child.WriteRawPacket(vv) +} + +// Wait implements stack.LinkEndpoint. +func (e *Endpoint) Wait() { + e.child.Wait() +} + +// GSOMaxSize implements stack.GSOEndpoint. +func (e *Endpoint) GSOMaxSize() uint32 { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.GSOMaxSize() + } + return 0 +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.child.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.child.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go new file mode 100644 index 000000000..c1f9d308c --- /dev/null +++ b/pkg/tcpip/link/nested/nested_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nested_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type parentEndpoint struct { + nested.Endpoint +} + +var _ stack.LinkEndpoint = (*parentEndpoint)(nil) +var _ stack.NetworkDispatcher = (*parentEndpoint)(nil) + +type childEndpoint struct { + stack.LinkEndpoint + dispatcher stack.NetworkDispatcher +} + +var _ stack.LinkEndpoint = (*childEndpoint)(nil) + +func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + c.dispatcher = dispatcher +} + +func (c *childEndpoint) IsAttached() bool { + return c.dispatcher != nil +} + +type counterDispatcher struct { + count int +} + +var _ stack.NetworkDispatcher = (*counterDispatcher)(nil) + +func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + d.count++ +} + +func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestNestedLinkEndpoint(t *testing.T) { + const emptyAddress = tcpip.LinkAddress("") + + var ( + childEP childEndpoint + nestedEP parentEndpoint + disp counterDispatcher + ) + nestedEP.Endpoint.Init(&childEP, &nestedEP) + + if childEP.IsAttached() { + t.Error("On init, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("On init, nestedEP.IsAttached() = true, want = false") + } + + nestedEP.Attach(&disp) + if disp.count != 0 { + t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count) + } + if !childEP.IsAttached() { + t.Error("After attach, childEP.IsAttached() = false, want = true") + } + if !nestedEP.IsAttached() { + t.Error("After attach, nestedEP.IsAttached() = false, want = true") + } + + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + if disp.count != 1 { + t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) + } + + nestedEP.Attach(nil) + if childEP.IsAttached() { + t.Error("After detach, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("After detach, nestedEP.IsAttached() = true, want = false") + } + + disp.count = 0 + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + if disp.count != 0 { + t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) + } + +} diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD new file mode 100644 index 000000000..6fff160ce --- /dev/null +++ b/pkg/tcpip/link/packetsocket/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "packetsocket", + srcs = ["endpoint.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go new file mode 100644 index 000000000..3922c2a04 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packetsocket provides a link layer endpoint that provides the ability +// to loop outbound packets to any AF_PACKET sockets that may be interested in +// the outgoing packet. +package packetsocket + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + nested.Endpoint +} + +// New creates a new packetsocket LinkEndpoint. +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + e := &endpoint{} + e.Endpoint.Init(lower, e) + return e +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD new file mode 100644 index 000000000..1d0079bd6 --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "fifo", + srcs = [ + "endpoint.go", + "packet_buffer_queue.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sleep", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go new file mode 100644 index 000000000..fc1e34fc7 --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -0,0 +1,227 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fifo provides the implementation of data-link layer endpoints that +// wrap another endpoint and queues all outbound packets and asynchronously +// dispatches them to the lower endpoint. +package fifo + +import ( + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// endpoint represents a LinkEndpoint which implements a FIFO queue for all +// outgoing packets. endpoint can have 1 or more underlying queueDispatchers. +// All outgoing packets are consistenly hashed to a single underlying queue +// using the PacketBuffer.Hash if set, otherwise all packets are queued to the +// first queue to avoid reordering in case of missing hash. +type endpoint struct { + dispatcher stack.NetworkDispatcher + lower stack.LinkEndpoint + wg sync.WaitGroup + dispatchers []*queueDispatcher +} + +// queueDispatcher is responsible for dispatching all outbound packets in its +// queue. It will also smartly batch packets when possible and write them +// through the lower LinkEndpoint. +type queueDispatcher struct { + lower stack.LinkEndpoint + q *packetBufferQueue + newPacketWaker sleep.Waker + closeWaker sleep.Waker +} + +// New creates a new fifo link endpoint with the n queues with maximum +// capacity of queueLen. +func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint { + e := &endpoint{ + lower: lower, + } + // Create the required dispatchers + for i := 0; i < n; i++ { + qd := &queueDispatcher{ + q: &packetBufferQueue{limit: queueLen}, + lower: lower, + } + e.dispatchers = append(e.dispatchers, qd) + e.wg.Add(1) + go func() { + defer e.wg.Done() + qd.dispatchLoop() + }() + } + return e +} + +func (q *queueDispatcher) dispatchLoop() { + const newPacketWakerID = 1 + const closeWakerID = 2 + s := sleep.Sleeper{} + s.AddWaker(&q.newPacketWaker, newPacketWakerID) + s.AddWaker(&q.closeWaker, closeWakerID) + defer s.Done() + + const batchSize = 32 + var batch stack.PacketBufferList + for { + id, ok := s.Fetch(true) + if ok && id == closeWakerID { + return + } + for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() { + batch.PushBack(pkt) + if batch.Len() < batchSize && !q.q.empty() { + continue + } + // We pass a protocol of zero here because each packet carries its + // NetworkProtocol. + q.lower.WritePackets(nil /* route */, nil /* gso */, batch, 0 /* protocol */) + for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() { + pkt.EgressRoute.Release() + batch.Remove(pkt) + } + batch.Reset() + } + } +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) +} + +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) +} + +// Attach implements stack.LinkEndpoint.Attach. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. +func (e *endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. +func (e *endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + // WritePacket caller's do not set the following fields in PacketBuffer + // so we populate them here. + newRoute := r.Clone() + pkt.EgressRoute = &newRoute + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol + d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] + if !d.q.enqueue(pkt) { + return tcpip.ErrNoBufferSpace + } + d.newPacketWaker.Assert() + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +// +// Being a batch API, each packet in pkts should have the following fields +// populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + enqueued := 0 + for pkt := pkts.Front(); pkt != nil; { + d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] + nxt := pkt.Next() + // Since qdisc can hold onto a packet for long we should Clone + // the route here to ensure it doesn't get released while the + // packet is still in our queue. + newRoute := pkt.EgressRoute.Clone() + pkt.EgressRoute = &newRoute + if !d.q.enqueue(pkt) { + if enqueued > 0 { + d.newPacketWaker.Assert() + } + return enqueued, tcpip.ErrNoBufferSpace + } + pkt = nxt + enqueued++ + d.newPacketWaker.Assert() + } + return enqueued, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // TODO(gvisor.dev/issue/3267): Queue these packets as well once + // WriteRawPacket takes PacketBuffer instead of VectorisedView. + return e.lower.WriteRawPacket(vv) +} + +// Wait implements stack.LinkEndpoint.Wait. +func (e *endpoint) Wait() { + e.lower.Wait() + + // The linkEP is gone. Teardown the outbound dispatcher goroutines. + for i := range e.dispatchers { + e.dispatchers[i].closeWaker.Assert() + } + + e.wg.Wait() +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go new file mode 100644 index 000000000..eb5abb906 --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fifo + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// packetBufferQueue is a bounded, thread-safe queue of PacketBuffers. +// +type packetBufferQueue struct { + mu sync.Mutex + list stack.PacketBufferList + limit int + used int +} + +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *packetBufferQueue) emptyLocked() bool { + return q.used == 0 +} + +// empty determines if the queue is empty. +func (q *packetBufferQueue) empty() bool { + q.mu.Lock() + r := q.emptyLocked() + q.mu.Unlock() + + return r +} + +// setLimit updates the limit. No PacketBuffers are immediately dropped in case +// the queue becomes full due to the new limit. +func (q *packetBufferQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} + +// enqueue adds the given packet to the queue. +// +// Returns true when the PacketBuffer is successfully added to the queue, in +// which case ownership of the reference is transferred to the queue. And +// returns false if the queue is full, in which case ownership is retained by +// the caller. +func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used++ + } + q.mu.Unlock() + + return r +} + +// dequeue removes and returns the next PacketBuffer from queue, if one exists. +// Ownership is transferred to the caller. +func (q *packetBufferQueue) dequeue() *stack.PacketBuffer { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used-- + } + q.mu.Unlock() + + return s +} diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 05c7b8024..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,10 +12,7 @@ go_library( "errors.go", "rawfile_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index dda3b10a6..99313ee25 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.14 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..f4c32c2da 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -66,39 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { return nil } -// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a -// single syscall. It fails if partial data is written. -func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { - // If the is no second buffer, issue a regular write. - if len(b2) == 0 { - return NonBlockingWrite(fd, b1) - } - - // We have two buffers. Build the iovec that represents them and issue - // a writev syscall. - iovec := [3]syscall.Iovec{ - { - Base: &b1[0], - Len: uint64(len(b1)), - }, - { - Base: &b2[0], - Len: uint64(len(b2)), - }, - } - iovecLen := uintptr(2) - - if len(b3) > 0 { - iovecLen++ - iovec[2].Base = &b3[0] - iovec[2].Len = uint64(len(b3)) - } - +// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. +// It fails if partial data is written. +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { + iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { return TranslateErrno(e) } - return nil } diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 0a5ea3dc4..13243ebbb 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,12 +10,10 @@ go_library( "sharedmem_unsafe.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", - visibility = [ - "//:sandbox", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -31,8 +28,9 @@ go_test( srcs = [ "sharedmem_test.go", ], - embed = [":sharedmem"], + library = ":sharedmem", deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 330ed5e94..87020ec08 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,8 +10,7 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], ) go_test( @@ -20,5 +18,6 @@ go_test( srcs = [ "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", + deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index 59ef69a8b..dc239a0d0 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -18,8 +18,9 @@ import ( "math/rand" "reflect" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSimpleReadWrite(t *testing.T) { diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index de1ce043d..3ba06af73 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,8 +8,7 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip/link/sharedmem/pipe", @@ -22,7 +20,7 @@ go_test( srcs = [ "queue_test.go", ], - embed = [":queue"], + library = ":queue", deps = [ "//pkg/tcpip/link/sharedmem/pipe", ], diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 279e2b457..7fb8a6c49 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -23,11 +23,11 @@ package sharedmem import ( - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -183,26 +183,33 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - // Add the ethernet header here. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) - v := payload.ToView() + views := pkt.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(hdr.View(), v) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -213,16 +220,16 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - v := packet.ToView() +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + views := vv.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(v, buffer.View{}) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -268,13 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { rxb[i].Size = e.bufferSize } - if n < header.EthernetMinimumSize { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { continue } + eth := header.Ethernet(hdr) // Send packet up the stack. - eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), buffer.View(eth)) + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) } // Clean state. @@ -283,3 +295,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { e.completed.Done() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index f3e9705c9..22d5c97f1 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -22,11 +22,11 @@ import ( "math/rand" "os" "strings" - "sync" "syscall" "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -131,19 +131,22 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - vv: vv.Clone(nil), - linkHeader: linkHeader, + addr: remoteLinkAddr, + proto: proto, + vv: pkt.Data.Clone(nil), }) c.mu.Unlock() c.packetCh <- struct{}{} } +func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (c *testContext) cleanup() { c.ep.Close() closeFDs(&c.txCfg) @@ -263,18 +266,23 @@ func TestSimpleSend(t *testing.T) { for iters := 1000; iters > 0; iters-- { func() { + hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) + // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) + hdrBuf := buffer.NewView(hdrLen) randomFill(hdrBuf) - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) + data := buffer.NewView(dataLen) + randomFill(data) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), + Data: data.ToVectorisedView(), + }) + copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -311,7 +319,7 @@ func TestSimpleSend(t *testing.T) { // Compare contents skipping the ethernet header added by the // endpoint. - merged := append(hdrBuf, buf...) + merged := append(hdrBuf, data...) if uint32(len(contents)) < pi.Size { t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) } @@ -338,12 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) { LocalLinkAddress: newLocalLinkAddress, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + ReserveHeaderBytes: header.EthernetMinimumSize, + }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -395,9 +405,12 @@ func TestFillTxQueue(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -411,8 +424,11 @@ func TestFillTxQueue(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -436,8 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -456,8 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -471,8 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -494,8 +519,11 @@ func TestFillTxMemory(t *testing.T) { // we fill the memory. ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -510,8 +538,11 @@ func TestFillTxMemory(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -535,8 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Each packet is uses up one buffer, so write as many as possible // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -547,17 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write a two-buffer packet. It must fail. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buffer.NewView(bufferSize).ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } // Attempt to write the one-buffer packet again. It must succeed. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -640,7 +679,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 6b8d7859d..44f421c2d 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -76,9 +77,9 @@ func (t *tx) cleanup() { syscall.Munmap(t.data) } -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { +// transmit sends a packet made of bufs. Returns a boolean that specifies +// whether the packet was successfully transmitted. +func (t *tx) transmit(bufs ...buffer.View) bool { // Pull completions from the tx queue and add their buffers back to the // pool so that we can reuse them. for { @@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool { } bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) + total := uint32(0) + for _, data := range bufs { + total += uint32(len(data)) + } bufCount := (total + bSize - 1) / bSize // Allocate enough buffers to hold all the data. @@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool { // Copy data into allocated buffers. nBuf := buf var dBuf []byte - for _, data := range [][]byte{a, b} { + for _, data := range bufs { for len(data) > 0 { if len(dBuf) == 0 { dBuf = t.data[nBuf.Offset:][:nBuf.Size] diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 1756114e6..7cbc305e7 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,15 +8,13 @@ go_library( "pcap.go", "sniffer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 39757ea2a..4fb127978 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -21,11 +21,9 @@ package sniffer import ( - "bytes" "encoding/binary" "fmt" "io" - "os" "sync/atomic" "time" @@ -33,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -42,26 +41,29 @@ import ( // LogPackets must be accessed atomically. var LogPackets uint32 = 1 -// LogPacketsToFile is a flag used to enable or disable logging packets to a -// pcap file. Valid values are 0 or 1. A file must have been specified when the +// LogPacketsToPCAP is a flag used to enable or disable logging packets to a +// pcap writer. Valid values are 0 or 1. A writer must have been specified when the // sniffer was created for this flag to have effect. // -// LogPacketsToFile must be accessed atomically. -var LogPacketsToFile uint32 = 1 +// LogPacketsToPCAP must be accessed atomically. +var LogPacketsToPCAP uint32 = 1 type endpoint struct { - dispatcher stack.NetworkDispatcher - lower stack.LinkEndpoint - file *os.File + nested.Endpoint + writer io.Writer maxPCAPLen uint32 } +var _ stack.GSOEndpoint = (*endpoint)(nil) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.NetworkDispatcher = (*endpoint)(nil) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - return &endpoint{ - lower: lower, - } + sniffer := &endpoint{} + sniffer.Endpoint.Init(lower, sniffer) + return sniffer } func zoneOffset() (int32, error) { @@ -92,132 +94,72 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { }) } -// NewWithFile creates a new sniffer link-layer endpoint. It wraps around -// another endpoint and logs packets and they traverse the endpoint. +// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around +// another endpoint and logs packets as they traverse the endpoint. // -// Packets can be logged to file in the pcap format. A sniffer created -// with this function will not emit packets using the standard log -// package. +// Packets are logged to writer in the pcap format. A sniffer created with this +// function will not emit packets using the standard log package. // // snapLen is the maximum amount of a packet to be saved. Packets with a length -// less than or equal too snapLen will be saved in their entirety. Longer +// less than or equal to snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. -func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { - if err := writePCAPHeader(file, snapLen); err != nil { +func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (stack.LinkEndpoint, error) { + if err := writePCAPHeader(writer, snapLen); err != nil { return nil, err } - return &endpoint{ - lower: lower, - file: file, + sniffer := &endpoint{ + writer: writer, maxPCAPLen: snapLen, - }, nil + } + sniffer.Endpoint.Init(lower, sniffer) + return sniffer, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { - if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First(), nil) - } - if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := vv.Views() - length := vv.Size() - if length > int(e.maxPCAPLen) { - length = int(e.maxPCAPLen) - } - - buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { - panic(err) - } - for _, v := range vs { - if length == 0 { - break - } - if len(v) > length { - v = v[:length] - } - if _, err := buf.Write([]byte(v)); err != nil { - panic(err) - } - length -= len(v) - } - if _, err := e.file.Write(buf.Bytes()); err != nil { - panic(err) - } - } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader) -} - -// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher -// and registers with the lower endpoint as its dispatcher so that "e" is called -// for inbound packets. -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - e.lower.Attach(e) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dumpPacket("recv", nil, protocol, pkt) + e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *endpoint) IsAttached() bool { - return e.dispatcher != nil +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } -// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the -// lower endpoint. -func (e *endpoint) MTU() uint32 { - return e.lower.MTU() -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the -// request to the lower endpoint. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.lower.Capabilities() -} - -// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards -// the request to the lower endpoint. -func (e *endpoint) MaxHeaderLength() uint16 { - return e.lower.MaxHeaderLength() -} - -func (e *endpoint) LinkAddress() tcpip.LinkAddress { - return e.lower.LinkAddress() -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.lower.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 -} - -func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) { - if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, hdr.View(), gso) +func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + writer := e.writer + if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { + logPacket(prefix, protocol, pkt, gso) } - if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - hdrBuf := hdr.View() - length := len(hdrBuf) + payload.Size() - if length > int(e.maxPCAPLen) { - length = int(e.maxPCAPLen) + if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { + totalLength := pkt.Size() + length := totalLength + if max := int(e.maxPCAPLen); length > max { + length = max } - - buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil { + if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil { panic(err) } - if len(hdrBuf) > length { - hdrBuf = hdrBuf[:length] - } - if _, err := buf.Write(hdrBuf); err != nil { - panic(err) + write := func(b []byte) { + if len(b) > length { + b = b[:length] + } + for len(b) != 0 { + n, err := writer.Write(b) + if err != nil { + panic(err) + } + b = b[n:] + length -= n + } } - length -= len(hdrBuf) - logVectorisedView(payload, length, buf) - if _, err := e.file.Write(buf.Bytes()); err != nil { - panic(err) + for _, v := range pkt.Views() { + if length == 0 { + break + } + write(v) } } } @@ -225,68 +167,30 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload bu // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - e.dumpPacket(gso, hdr, payload, protocol) - return e.lower.WritePacket(r, gso, hdr, payload, protocol) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.dumpPacket("send", gso, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - view := payload.ToView() - for _, d := range hdrs { - e.dumpPacket(gso, d.Hdr, buffer.NewVectorisedView(d.Size, []buffer.View{view[d.Off:][:d.Size]}), protocol) +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.dumpPacket("send", gso, protocol, pkt) } - return e.lower.WritePackets(r, gso, hdrs, payload, protocol) + return e.Endpoint.WritePackets(r, gso, pkts, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) - } - if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - length := packet.Size() - if length > int(e.maxPCAPLen) { - length = int(e.maxPCAPLen) - } - - buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(packet.Size()))); err != nil { - panic(err) - } - logVectorisedView(packet, length, buf) - if _, err := e.file.Write(buf.Bytes()); err != nil { - panic(err) - } - } - return e.lower.WriteRawPacket(packet) -} - -func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { - if length <= 0 { - return - } - for _, v := range vv.Views() { - if len(v) > length { - v = v[:length] - } - n, err := buf.Write(v) - if err != nil { - panic(err) - } - length -= n - if length == 0 { - return - } - } +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + return e.Endpoint.WriteRawPacket(vv) } -// Wait implements stack.LinkEndpoint.Wait. -func (*endpoint) Wait() {} - -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -295,30 +199,47 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Examine the packet using a new VV. Backing storage must not be written. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( - "%s arp %v (%v) -> %v (%v) valid:%v", + "%s arp %s (%s) -> %s (%s) valid:%t", prefix, tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), @@ -338,7 +259,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -366,12 +291,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie icmpType = "info reply" } } - log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -397,13 +326,17 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6RedirectMsg: icmpType = "redirect message" } - log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) - if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) + if fragmentOffset == 0 { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) @@ -412,15 +345,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) - if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) + if fragmentOffset == 0 { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if offset > len(tcp) && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) + if offset > vv.Size() && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size()) break } @@ -436,7 +373,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie flagsStr[i] = ' ' } } - details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) + details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) if flags&header.TCPFlagSyn != 0 { details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) } else { @@ -445,7 +382,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie } default: - log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto) + log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto) return } @@ -453,5 +390,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie details += fmt.Sprintf(" gso: %+v", gso) } - log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) + log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 92dce8fac..6c137f693 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,12 +1,26 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "tun", - srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", - visibility = [ - "//visibility:public", + srcs = [ + "device.go", + "protocol.go", + "tun_unsafe.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/refs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + "//pkg/waiter", ], ) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go new file mode 100644 index 000000000..3b1510a33 --- /dev/null +++ b/pkg/tcpip/link/tun/device.go @@ -0,0 +1,383 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // drivers/net/tun.c:tun_net_init() + defaultDevMtu = 1500 + + // Queue length for outbound packet, arriving at fd side for read. Overflow + // causes packet drops. gVisor implementation-specific. + defaultDevOutQueueLen = 1024 +) + +var zeroMAC [6]byte + +// Device is an opened /dev/net/tun device. +// +// +stateify savable +type Device struct { + waiter.Queue + + mu sync.RWMutex `state:"nosave"` + endpoint *tunEndpoint + notifyHandle *channel.NotificationHandle + flags uint16 +} + +// beforeSave is invoked by stateify. +func (d *Device) beforeSave() { + d.mu.Lock() + defer d.mu.Unlock() + // TODO(b/110961832): Restore the device to stack. At this moment, the stack + // is not savable. + if d.endpoint != nil { + panic("/dev/net/tun does not support save/restore when a device is associated with it.") + } +} + +// Release implements fs.FileOperations.Release. +func (d *Device) Release(ctx context.Context) { + d.mu.Lock() + defer d.mu.Unlock() + + // Decrease refcount if there is an endpoint associated with this file. + if d.endpoint != nil { + d.endpoint.RemoveNotify(d.notifyHandle) + d.endpoint.DecRef(ctx) + d.endpoint = nil + } +} + +// SetIff services TUNSETIFF ioctl(2) request. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.endpoint != nil { + return syserror.EINVAL + } + + // Input validations. + isTun := flags&linux.IFF_TUN != 0 + isTap := flags&linux.IFF_TAP != 0 + supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) + if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { + return syserror.EINVAL + } + + prefix := "tun" + if isTap { + prefix = "tap" + } + + linkCaps := stack.CapabilityNone + if isTap { + linkCaps |= stack.CapabilityResolutionRequired + } + + endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) + if err != nil { + return syserror.EINVAL + } + + d.endpoint = endpoint + d.notifyHandle = d.endpoint.AddNotify(d) + d.flags = flags + return nil +} + +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { + for { + // 1. Try to attach to an existing NIC. + if name != "" { + if nic, found := s.GetNICByName(name); found { + endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if !ok { + // Not a NIC created by tun device. + return nil, syserror.EOPNOTSUPP + } + if !endpoint.TryIncRef() { + // Race detected: NIC got deleted in between. + continue + } + return endpoint, nil + } + } + + // 2. Creating a new NIC. + id := tcpip.NICID(s.UniqueID()) + endpoint := &tunEndpoint{ + Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), + stack: s, + nicID: id, + name: name, + isTap: prefix == "tap", + } + endpoint.Endpoint.LinkEPCapabilities = linkCaps + if endpoint.name == "" { + endpoint.name = fmt.Sprintf("%s%d", prefix, id) + } + err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ + Name: endpoint.name, + }) + switch err { + case nil: + return endpoint, nil + case tcpip.ErrDuplicateNICID: + // Race detected: A NIC has been created in between. + continue + default: + return nil, syserror.EINVAL + } + } +} + +// Write inject one inbound packet to the network interface. +func (d *Device) Write(data []byte) (int64, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return 0, syserror.EBADFD + } + if !endpoint.IsAttached() { + return 0, syserror.EIO + } + + dataLen := int64(len(data)) + + // Packet information. + var pktInfoHdr PacketInfoHeader + if !d.hasFlags(linux.IFF_NO_PI) { + if len(data) < PacketInfoHeaderSize { + // Ignore bad packet. + return dataLen, nil + } + pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize]) + data = data[PacketInfoHeaderSize:] + } + + // Ethernet header (TAP only). + var ethHdr header.Ethernet + if d.hasFlags(linux.IFF_TAP) { + if len(data) < header.EthernetMinimumSize { + // Ignore bad packet. + return dataLen, nil + } + ethHdr = header.Ethernet(data[:header.EthernetMinimumSize]) + data = data[header.EthernetMinimumSize:] + } + + // Try to determine network protocol number, default zero. + var protocol tcpip.NetworkProtocolNumber + switch { + case pktInfoHdr != nil: + protocol = pktInfoHdr.Protocol() + case ethHdr != nil: + protocol = ethHdr.Type() + } + + // Try to determine remote link address, default zero. + var remote tcpip.LinkAddress + switch { + case ethHdr != nil: + remote = ethHdr.SourceAddress() + default: + remote = tcpip.LinkAddress(zeroMAC[:]) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: len(ethHdr), + Data: buffer.View(data).ToVectorisedView(), + }) + copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr) + endpoint.InjectLinkAddr(protocol, remote, pkt) + return dataLen, nil +} + +// Read reads one outgoing packet from the network interface. +func (d *Device) Read() ([]byte, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return nil, syserror.EBADFD + } + + for { + info, ok := endpoint.Read() + if !ok { + return nil, syserror.ErrWouldBlock + } + + v, ok := d.encodePkt(&info) + if !ok { + // Ignore unsupported packet. + continue + } + return v, nil + } +} + +// encodePkt encodes packet for fd side. +func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { + var vv buffer.VectorisedView + + // Packet information. + if !d.hasFlags(linux.IFF_NO_PI) { + hdr := make(PacketInfoHeader, PacketInfoHeaderSize) + hdr.Encode(&PacketInfoFields{ + Protocol: info.Proto, + }) + vv.AppendView(buffer.View(hdr)) + } + + // If the packet does not already have link layer header, and the route + // does not exist, we can't compute it. This is possibly a raw packet, tun + // device doesn't support this at the moment. + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { + return nil, false + } + + // Ethernet header (TAP only). + if d.hasFlags(linux.IFF_TAP) { + // Add ethernet header if not provided. + if info.Pkt.LinkHeader().View().IsEmpty() { + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) + } + vv.AppendView(info.Pkt.LinkHeader().View()) + } + + // Append upper headers. + vv.AppendView(info.Pkt.NetworkHeader().View()) + vv.AppendView(info.Pkt.TransportHeader().View()) + // Append data payload. + vv.Append(info.Pkt.Data) + + return vv.ToView(), true +} + +// Name returns the name of the attached network interface. Empty string if +// unattached. +func (d *Device) Name() string { + d.mu.RLock() + defer d.mu.RUnlock() + if d.endpoint != nil { + return d.endpoint.name + } + return "" +} + +// Flags returns the flags set for d. Zero value if unset. +func (d *Device) Flags() uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + return d.flags +} + +func (d *Device) hasFlags(flags uint16) bool { + return d.flags&flags == flags +} + +// Readiness implements watier.Waitable.Readiness. +func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint != nil && endpoint.NumQueued() == 0 { + mask &= ^waiter.EventIn + } + } + return mask & (waiter.EventIn | waiter.EventOut) +} + +// WriteNotify implements channel.Notification.WriteNotify. +func (d *Device) WriteNotify() { + d.Notify(waiter.EventIn) +} + +// tunEndpoint is the link endpoint for the NIC created by the tun device. +// +// It is ref-counted as multiple opening files can attach to the same NIC. +// The last owner is responsible for deleting the NIC. +type tunEndpoint struct { + *channel.Endpoint + + refs.AtomicRefCount + + stack *stack.Stack + nicID tcpip.NICID + name string + isTap bool +} + +// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +func (e *tunEndpoint) DecRef(ctx context.Context) { + e.DecRefWithDestructor(ctx, func(context.Context) { + e.stack.RemoveNIC(e.nicID) + }) +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.isTap { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.isTap { + return + } + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + hdr := &header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: protocol, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = e.LinkAddress() + } + + eth.Encode(hdr) +} + +// MaxHeaderLength returns the maximum size of the link layer header. +func (e *tunEndpoint) MaxHeaderLength() uint16 { + if e.isTap { + return header.EthernetMinimumSize + } + return 0 +} diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go new file mode 100644 index 000000000..89d9d91a9 --- /dev/null +++ b/pkg/tcpip/link/tun/protocol.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // PacketInfoHeaderSize is the size of the packet information header. + PacketInfoHeaderSize = 4 + + offsetFlags = 0 + offsetProtocol = 2 +) + +// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is +// not set. +type PacketInfoFields struct { + Flags uint16 + Protocol tcpip.NetworkProtocolNumber +} + +// PacketInfoHeader is the wire representation of the packet information sent if +// IFF_NO_PI flag is not set. +type PacketInfoHeader []byte + +// Encode encodes f into h. +func (h PacketInfoHeader) Encode(f *PacketInfoFields) { + binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags) + binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol)) +} + +// Flags returns the flag field in h. +func (h PacketInfoHeader) Flags() uint16 { + return binary.BigEndian.Uint16(h[offsetFlags:]) +} + +// Protocol returns the protocol field in h. +func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber { + return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:])) +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0746dc8ec..ee84c3d96 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,14 +7,12 @@ go_library( srcs = [ "waitable.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/gate", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -25,10 +22,11 @@ go_test( srcs = [ "waitable_test.go", ], - embed = [":waitable"], + library = ":waitable", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a04fc1062..b152a0f26 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -50,12 +51,21 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader) + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) e.dispatchGate.Leave() } @@ -99,12 +109,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, hdr, payload, protocol) + err := e.lower.WritePacket(r, gso, protocol, pkt) e.writeGate.Leave() return err } @@ -112,23 +122,23 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { - return len(hdrs), nil + return pkts.Len(), nil } - n, err := e.lower.WritePackets(r, gso, hdrs, payload, protocol) + n, err := e.lower.WritePackets(r, gso, pkts, protocol) e.writeGate.Leave() return n, err } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WriteRawPacket(packet) + err := e.lower.WriteRawPacket(vv) e.writeGate.Leave() return err } @@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() { // Wait implements stack.LinkEndpoint.Wait. func (e *Endpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 5f0f8fa2d..94827fc56 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -35,10 +36,14 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatchCount++ } +func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.attachCount++ e.dispatcher = dispatcher @@ -65,45 +70,55 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(hdrs) - return len(hdrs), nil +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += pkts.Len() + return pkts.Len(), nil } -func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { e.writeCount++ return nil } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unimplemented") +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +135,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9d16ff8c9..46083925c 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) @@ -12,6 +12,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index df0d3a8c0..eddf7b725 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,15 +1,11 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "arp", srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 46178459e..920872c3f 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,7 +42,8 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { - nicid tcpip.NICID + protocol *protocol + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -58,43 +59,39 @@ func (e *endpoint) MTU() uint32 { } func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.linkEP.Capabilities() } -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &stack.NetworkEndpointID{ProtocolAddress} -} - -func (e *endpoint) PrefixLen() int { - return 0 -} - func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.ARPSize } func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescriptor, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() - h := header.ARP(v) +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return } @@ -102,23 +99,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } - hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - pkt := header.ARP(hdr.Prepend(header.ARPSize)) - pkt.SetIPv4OverEthernet() - pkt.SetOp(header.ARPReply) - copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + }) + packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) } } @@ -135,76 +134,77 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { - if addrWithPrefix.Address != ProtocolAddress { - return nil, tcpip.ErrBadLocalAddress - } +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ - nicid: nicid, + protocol: p, + nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, - }, nil + } } -// LinkAddressProtocol implements stack.LinkAddressResolver. +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -// LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) - h := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + }) + h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) copy(h.HardwareAddressSender(), linkEP.LinkAddress()) copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } -// ResolveStaticAddress implements stack.LinkAddressResolver. +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { - return broadcastMAC, true + return header.EthernetBroadcastAddress, true } if header.IsV4MulticastAddress(addr) { - // RFC 1112 Host Extensions for IP Multicasting - // - // 6.4. Extensions to an Ethernet Local Network Module: - // - // An IP host group address is mapped to an Ethernet multicast - // address by placing the low-order 23-bits of the IP address - // into the low-order 23 bits of the Ethernet multicast address - // 01-00-5E-00-00-00 (hex). - return tcpip.LinkAddress([]byte{ - 0x01, - 0x00, - 0x5e, - addr[header.IPv4AddressSize-3] & 0x7f, - addr[header.IPv4AddressSize-2], - addr[header.IPv4AddressSize-1], - }), true + return header.EthernetAddressFromMulticastIPv4Address(addr), true } - return "", false + return tcpip.LinkAddress([]byte(nil)), false } -// SetOption implements NetworkProtocol. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements NetworkProtocol. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + _, ok = pkt.NetworkHeader().Consume(header.ARPSize) + if !ok { + return 0, false, false + } + return 0, false, true +} // NewProtocol returns an ARP network protocol. func NewProtocol() stack.NetworkProtocol { diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 88b57ec03..c2c3e6891 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,6 +15,7 @@ package arp_test import ( + "context" "strconv" "testing" "time" @@ -31,10 +32,14 @@ import ( ) const ( - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d") + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + + defaultChannelSize = 1 + defaultMTU = 65536 ) type testContext struct { @@ -49,8 +54,7 @@ func newTestContext(t *testing.T) *testContext { TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, }) - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) wep := stack.LinkEndpoint(ep) if testing.Verbose() { @@ -83,7 +87,7 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP.C) + c.linkEP.Close() } func TestDirectRequest(t *testing.T) { @@ -102,21 +106,23 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + })) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + pi, _ := c.linkEP.ReadContext(context.Background()) + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pkt.Header) + rep := header.ARP(pi.Pkt.NetworkHeader().View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) } if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { @@ -132,12 +138,53 @@ func TestDirectRequest(t *testing.T) { } inject(stackAddrBad) - select { - case pkt := <-c.linkEP.C: + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. + } +} + +func TestLinkAddressRequest(t *testing.T) { + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: stackLinkAddr2, + expectLinkAddr: stackLinkAddr2, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: header.EthernetBroadcastAddress, + }, + } + + for _, test := range tests { + p := arp.NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) + if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } } } diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 2cad0a0b6..d1c728ccf 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -24,10 +23,10 @@ go_library( "reassembler.go", "reassembler_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", ], @@ -41,14 +40,6 @@ go_test( "fragmentation_test.go", "reassembler_test.go", ], - embed = [":fragmentation"], + library = ":fragmentation", deps = ["//pkg/tcpip/buffer"], ) - -filegroup( - name = "autogen", - srcs = [ - "reassembler_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6da5238ec..1827666c5 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -17,28 +17,58 @@ package fragmentation import ( + "errors" "fmt" "log" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. -const DefaultReassembleTimeout = 30 * time.Second +const ( + // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. + DefaultReassembleTimeout = 30 * time.Second -// HighFragThreshold is the threshold at which we start trimming old -// fragmented packets. Linux uses a default value of 4 MB. See -// net.ipv4.ipfrag_high_thresh for more information. -const HighFragThreshold = 4 << 20 // 4MB + // HighFragThreshold is the threshold at which we start trimming old + // fragmented packets. Linux uses a default value of 4 MB. See + // net.ipv4.ipfrag_high_thresh for more information. + HighFragThreshold = 4 << 20 // 4MB -// LowFragThreshold is the threshold we reach to when we start dropping -// older fragmented packets. It's important that we keep enough room for newer -// packets to be re-assembled. Hence, this needs to be lower than -// HighFragThreshold enough. Linux uses a default value of 3 MB. See -// net.ipv4.ipfrag_low_thresh for more information. -const LowFragThreshold = 3 << 20 // 3MB + // LowFragThreshold is the threshold we reach to when we start dropping + // older fragmented packets. It's important that we keep enough room for newer + // packets to be re-assembled. Hence, this needs to be lower than + // HighFragThreshold enough. Linux uses a default value of 3 MB. See + // net.ipv4.ipfrag_low_thresh for more information. + LowFragThreshold = 3 << 20 // 3MB + + // minBlockSize is the minimum block size for fragments. + minBlockSize = 1 +) + +var ( + // ErrInvalidArgs indicates to the caller that that an invalid argument was + // provided. + ErrInvalidArgs = errors.New("invalid args") +) + +// FragmentID is the identifier for a fragment. +type FragmentID struct { + // Source is the source address of the fragment. + Source tcpip.Address + + // Destination is the destination address of the fragment. + Destination tcpip.Address + + // ID is the identification value of the fragment. + // + // This is a uint32 because IPv6 uses a 32-bit identification value. + ID uint32 + + // The protocol for the packet. + Protocol uint8 +} // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. @@ -46,14 +76,17 @@ type Fragmentation struct { mu sync.Mutex highLimit int lowLimit int - reassemblers map[uint32]*reassembler + reassemblers map[FragmentID]*reassembler rList reassemblerList size int timeout time.Duration + blockSize uint16 } // NewFragmentation creates a new Fragmentation. // +// blockSize specifies the fragment block size, in bytes. +// // highMemoryLimit specifies the limit on the memory consumed // by the fragments stored by Fragmentation (overhead of internal data-structures // is not accounted). Fragments are dropped when the limit is reached. @@ -64,7 +97,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -73,17 +106,46 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t lowMemoryLimit = 0 } + if blockSize < minBlockSize { + blockSize = minBlockSize + } + return &Fragmentation{ - reassemblers: make(map[uint32]*reassembler), + reassemblers: make(map[FragmentID]*reassembler), highLimit: highMemoryLimit, lowLimit: lowMemoryLimit, timeout: reassemblingTimeout, + blockSize: blockSize, } } -// Process processes an incoming fragment belonging to an ID -// and returns a complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { +// Process processes an incoming fragment belonging to an ID and returns a +// complete packet when all the packets belonging to that ID have been received. +// +// [first, last] is the range of the fragment bytes. +// +// first must be a multiple of the block size f is configured with. The size +// of the fragment data must be a multiple of the block size, unless there are +// no fragments following this fragment (more set to false). +func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { + if first > last { + return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + } + + if first%f.blockSize != 0 { + return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + } + + fragmentSize := last - first + 1 + if more && fragmentSize%f.blockSize != 0 { + return buffer.VectorisedView{}, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + } + + if l := vv.Size(); l < int(fragmentSize) { + return buffer.VectorisedView{}, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + } + vv.CapLength(int(fragmentSize)) + f.mu.Lock() r, ok := f.reassemblers[id] if ok && r.tooOld(f.timeout) { @@ -115,10 +177,12 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. if f.size > f.highLimit { - tail := f.rList.Back() - for f.size > f.lowLimit && tail != nil { + for f.size > f.lowLimit { + tail := f.rList.Back() + if tail == nil { + break + } f.release(tail) - tail = tail.Prev() } } f.mu.Unlock() diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 72c0f53be..9eedd33c4 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -15,6 +15,7 @@ package fragmentation import ( + "errors" "reflect" "testing" "time" @@ -33,7 +34,7 @@ func vv(size int, pieces ...string) buffer.VectorisedView { } type processInput struct { - id uint32 + id FragmentID first uint16 last uint16 more bool @@ -53,8 +54,8 @@ var processTestCases = []struct { { comment: "One ID", in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -64,10 +65,10 @@ var processTestCases = []struct { { comment: "Two IDs", in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -81,7 +82,7 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout) for i, in := range c.in { vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) if err != nil { @@ -110,14 +111,14 @@ func TestFragmentationProcess(t *testing.T) { func TestReassemblingTimeout(t *testing.T) { timeout := time.Millisecond - f := NewFragmentation(1024, 512, timeout) + f := NewFragmentation(minBlockSize, 1024, 512, timeout) // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, vv(1, "0")) // Sleep more than the timeout. time.Sleep(2 * timeout) // Send another fragment that completes a packet. // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) + _, done, err := f.Process(FragmentID{}, 1, 1, false, vv(1, "1")) if err != nil { t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) } @@ -127,35 +128,35 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(3, 1, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout) // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, vv(1, "0")) // Send first fragment with id = 1. - f.Process(1, 0, 0, true, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, vv(1, "1")) // Send first fragment with id = 2. - f.Process(2, 0, 0, true, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, vv(1, "2")) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(3, 0, 0, true, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, vv(1, "3")) - if _, ok := f.reassemblers[0]; ok { + if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") } - if _, ok := f.reassemblers[1]; ok { + if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { t.Errorf("Memory limits are not respected: id=1 has not been evicted.") } - if _, ok := f.reassemblers[3]; !ok { + if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") } } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(1, 0, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout) // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, vv(1, "0")) // Send the same packet again. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, vv(1, "0")) got := f.size want := 1 @@ -163,3 +164,97 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } + +func TestErrors(t *testing.T) { + tests := []struct { + name string + blockSize uint16 + first uint16 + last uint16 + more bool + data string + err error + }{ + { + name: "exact block size without more", + blockSize: 2, + first: 2, + last: 3, + more: false, + data: "01", + }, + { + name: "exact block size with more", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "01", + }, + { + name: "exact block size with more and extra data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "012", + }, + { + name: "exact block size with more and too little data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size with more", + blockSize: 2, + first: 2, + last: 2, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size without more", + blockSize: 2, + first: 2, + last: 2, + more: false, + data: "0", + }, + { + name: "first not a multiple of block size", + blockSize: 2, + first: 3, + last: 4, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + { + name: "first more than last", + blockSize: 2, + first: 4, + last: 3, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout) + _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, vv(len(test.data), test.data)) + if !errors.Is(err, test.err) { + t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, _, %v), want = (_, _, %v)", test.first, test.last, test.more, test.data, err, test.err) + } + if done { + t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, true, _), want = (_, false, _)", test.first, test.last, test.more, test.data) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9e002e396..50d30bbf0 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -32,7 +32,7 @@ type hole struct { type reassembler struct { reassemblerEntry - id uint32 + id FragmentID size int mu sync.Mutex holes []hole @@ -42,7 +42,7 @@ type reassembler struct { creationTime time.Time } -func newReassembler(id uint32) *reassembler { +func newReassembler(id FragmentID) *reassembler { r := &reassembler{ id: id, holes: make([]hole, 0, 16), diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index 7eee0710d..dff7c9dcb 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -94,7 +94,7 @@ var holesTestCases = []struct { func TestUpdateHoles(t *testing.T) { for _, c := range holesTestCases { - r := newReassembler(0) + r := newReassembler(FragmentID{}) for _, i := range c.in { r.updateHoles(i.first, i.last, i.more) } diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD index e6db5c0b0..872165866 100644 --- a/pkg/tcpip/network/hash/BUILD +++ b/pkg/tcpip/network/hash/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "hash", srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", visibility = ["//visibility:public"], deps = [ "//pkg/rand", diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go index 6a215938b..8f65713c5 100644 --- a/pkg/tcpip/network/hash/hash.go +++ b/pkg/tcpip/network/hash/hash.go @@ -80,12 +80,12 @@ func IPv4FragmentHash(h header.IPv4) uint32 { // RFC 2640 (sec 4.5) is not very sharp on this aspect. // As a reference, also Linux ignores the protocol to compute // the hash (inet6_hash_frag). -func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 { +func IPv6FragmentHash(h header.IPv6, id uint32) uint32 { t := h.SourceAddress() y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 t = h.DestinationAddress() z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 - return Hash3Words(f.ID(), y, z, hashIV) + return Hash3Words(id, y, z, hashIV) } func rol32(v, shift uint32) uint32 { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 666d8b92a..9007346fe 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -41,6 +42,7 @@ const ( ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + nicID = 1 ) // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. @@ -96,16 +98,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -150,50 +152,60 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(hdr.View()) + h := header.IPv4(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(hdr.View()) + h := header.IPv6(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, payload, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, hdr []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv4.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { @@ -201,24 +213,45 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv6.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) +} + +func buildDummyStack(t *testing.T) *stack.Stack { + t.Helper() + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) + e := channel.New(0, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err) + } + + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err) + } + + return s } func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, nil, &o, buildDummyStack(t)) + defer ep.Close() // Allocate and initialize the payload view. payload := buffer.NewView(100) @@ -226,8 +259,11 @@ func TestIPv4Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -239,7 +275,11 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -247,10 +287,8 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + defer ep.Close() totalLen := header.IPv4MinimumSize + 30 view := buffer.NewView(totalLen) @@ -279,7 +317,13 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -291,7 +335,7 @@ func TestIPv4ReceiveControl(t *testing.T) { name string expectedCount int fragmentOffset uint16 - code uint8 + code header.ICMPv4Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -313,10 +357,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) defer ep.Close() const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize @@ -366,8 +407,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -378,10 +418,8 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + defer ep.Close() totalLen := header.IPv4MinimumSize + 24 @@ -430,13 +468,25 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag1.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag2.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -445,10 +495,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) + defer ep.Close() // Allocate and initialize the payload view. payload := buffer.NewView(100) @@ -456,8 +504,11 @@ func TestIPv6Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -469,7 +520,11 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -477,10 +532,8 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + defer ep.Close() totalLen := header.IPv6MinimumSize + 30 view := buffer.NewView(totalLen) @@ -509,7 +562,13 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -525,7 +584,7 @@ func TestIPv6ReceiveControl(t *testing.T) { expectedCount int fragmentOffset *uint16 typ header.ICMPv6Type - code uint8 + code header.ICMPv6Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -552,11 +611,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) defer ep.Close() dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize @@ -618,15 +673,26 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } }) } } + +// truncatedPacket returns a PacketBuffer based on a truncated view. If view, +// after truncation, is large enough to hold a network header, it makes part of +// view the packet's NetworkHeader and the rest its Data. Otherwise all of view +// becomes Data. +func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { + v := view[:len(view)-trunc] + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + _, _ = pkt.NetworkHeader().Consume(netHdrLen) + return pkt +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 58e537aad..d142b4ffa 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,10 +8,7 @@ go_library( "icmp.go", "ipv4.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -38,5 +34,6 @@ go_test( "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 50b363dc4..b5659a36b 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -24,8 +24,12 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -33,13 +37,14 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // false. // // Drop packet if it doesn't have the basic IPv4 header or if the - // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { return } - hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,16 +52,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + pkt.Data.TrimFront(hlen) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() - if len(v) < header.ICMPv4MinimumSize { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } @@ -73,29 +81,64 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } + // Make a copy of data before pkt gets sent to raw socket. + // DeliverTransportPacket will take ownership of pkt. + replyData := pkt.Data.Clone(nil) + replyData.TrimFront(header.ICMPv4MinimumSize) + // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) - - vv := vv.Clone(nil) - vv.TrimFront(header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv4EchoReply) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + + remoteLinkAddr := r.RemoteLinkAddress + + // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP + // source address MUST be one of its own IP addresses (but not a broadcast + // or multicast address). + localAddr := r.LocalAddress + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // Use the remote link address from the incoming packet. + r.ResolveWith(remoteLinkAddr) + + // Prepare a reply packet. + icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) + copy(icmpHdr, h) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) + dataVV := buffer.View(icmpHdr).ToVectorisedView() + dataVV.Append(replyData) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: dataVV, + }) + + // Send out the reply packet. sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -104,19 +147,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { + case header.ICMPv4HostUnreachable: + e.handleControl(stack.ControlNoRoute, 0, pkt) + case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 90f4406e5..79872ec9a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -44,31 +44,29 @@ const ( // buckets is the number of identifier buckets. buckets = 2048 + + // The size of a fragment block, in bytes, as per RFC 791 section 3.1, + // page 14. + fragmentblockSize = 8 ) type endpoint struct { - nicid tcpip.NICID - id stack.NetworkEndpointID - prefixLen int - linkEP stack.LinkEndpoint - dispatcher stack.TransportDispatcher - fragmentation *fragmentation.Fragmentation - protocol *protocol + nicID tcpip.NICID + linkEP stack.LinkEndpoint + dispatcher stack.TransportDispatcher + protocol *protocol + stack *stack.Stack } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { - e := &endpoint{ - nicid: nicid, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, - dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), - protocol: p, - } - - return e, nil +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { + return &endpoint{ + nicID: nicID, + linkEP: linkEP, + dispatcher: dispatcher, + protocol: p, + stack: st, + } } // DefaultTTL is the default time-to-live value for this endpoint. @@ -89,17 +87,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid -} - -// ID returns the ipv4 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id -} - -// PrefixLen returns the ipv4 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen + return e.nicID } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and @@ -116,14 +104,18 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is already present in pkt.NetworkHeader. +// pkt.TransportHeader may be set. mtu includes the IP header and options. This +// does not support the DontFragment IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.NetworkHeader().View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -137,76 +129,88 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + + // Keep the length reserved for link-layer, we need to create fragments with + // the same reserved length. + reservedForLink := pkt.AvailableHeaderBytes() + + // Destroy the packet, pull all payloads out for fragmentation. + transHeader, data := pkt.TransportHeader().View(), pkt.Data + + // Where possible, the first fragment that is sent has the same + // number of bytes reserved for header as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. + transFitsFirst := len(transHeader) <= innerMTU + for i := 0; i < n; i++ { - // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. - h := ip - if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) - copy(h, ip[:ip.HeaderLength()]) + reserve := reservedForLink + int(ip.HeaderLength()) + if i == 0 && transFitsFirst { + // Reserve for transport header if it's going to be put in the first + // fragment. + reserve += len(transHeader) + } + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: reserve, + }) + fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + + // Copy data for the fragment. + avail := innerMTU + + if n := len(transHeader); n > 0 { + if n > avail { + n = avail + } + if i == 0 && transFitsFirst { + copy(fragPkt.TransportHeader().Push(n), transHeader) + } else { + fragPkt.Data.AppendView(transHeader[:n:n]) + } + transHeader = transHeader[n:] + avail -= n + } + + if avail > 0 { + n := data.Size() + if n > avail { + n = avail + } + data.ReadToVV(&fragPkt.Data, n) + avail -= n } + + copied := uint16(innerMTU - avail) + + // Set lengths in header and calculate checksum. + h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) + copy(h, ip) if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + copied) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) - offset += uint16(innerMTU) - if i > 0 { - newPayload := payload.Clone([]buffer.View{}) - newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) - continue - } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() - newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) - } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) - emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] - tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + offset += copied + + // Send out the fragment. + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + return err } + r.Stats().IP.PacketsSent.Increment() } return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) { - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payloadSize) - id := uint32(0) - if length > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) - } +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + length := uint16(pkt.Size()) + // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic + // datagrams. Since the DF bit is never being set here, all datagrams + // are non-atomic and need an ID. + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, @@ -218,28 +222,49 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - e.addIPHeader(r, &hdr, payload.Size(), params) - - if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { + e.addIPHeader(r, pkt, params) + + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + // iptables is telling us to drop the packet. + return nil + } + + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than + // only NATted packets, but removing this check short circuits broadcasts + // before they are sent out to other hosts. + if pkt.NatDone { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) + return nil + } + } + + if r.Loop&stack.PacketLoop != 0 { loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + e.HandlePacket(&loopedR, pkt) loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() @@ -247,34 +272,76 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } - if loop&stack.PacketOut == 0 { - return len(hdrs), nil + if r.Loop&stack.PacketOut == 0 { + return pkts.Len(), nil + } + + for pkt := pkts.Front(); pkt != nil; { + e.addIPHeader(r, pkt, params) + pkt = pkt.Next() + } + + nicName := e.stack.FindNICNameFromID(e.NICID()) + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err } - for i := range hdrs { - e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params) + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + ep.HandlePacket(&route, pkt) + n++ + continue + } + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -287,51 +354,49 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { - id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams, so assign an ID to all such datagrams + // according to the definition given in RFC 6864 section 4. + if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } - ip.SetID(uint16(id)) } // Always set the checksum. ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + if r.Loop&stack.PacketLoop != 0 { + e.HandlePacket(r, pkt.Clone()) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } - // If we want to send the packet to a link-layer, - // we have to reserve space for an Ethernet header. - hdr := buffer.NewPrependableFromView(payload.ToView(), int(e.linkEP.MaxHeaderLength())) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() - h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.IPv4(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + // iptables filtering. All packets that reach here are intended for + // this machine and will not be forwarded. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + // iptables is telling us to drop the packet. + return + } - more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 - if more || h.FragmentOffset() != 0 { - if vv.Size() == 0 { + if h.More() || h.FragmentOffset() != 0 { + if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -339,10 +404,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and vv.size() causes a wrap - // around resulting in last being less than the offset. + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. if last < h.FragmentOffset() { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -350,7 +415,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } var ready bool var err error - vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + pkt.Data, ready, err = e.protocol.fragmentation.Process( + // As per RFC 791 section 2.3, the identification value is unique + // for a source-destination pair and protocol. + fragmentation.FragmentID{ + Source: h.SourceAddress(), + Destination: h.DestinationAddress(), + ID: uint32(h.ID()), + Protocol: h.Protocol(), + }, + h.FragmentOffset(), + last, + h.More(), + pkt.Data, + ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -362,12 +440,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. @@ -381,6 +458,8 @@ type protocol struct { // uint8 portion of it is meaningful and it must be accessed // atomically. defaultTTL uint32 + + fragmentation *fragmentation.Fragmentation } // Number returns the ipv4 protocol number. @@ -436,6 +515,45 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv4(hdr) + + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return 0, false, false + } + ipHdr = header.IPv4(hdr) + + // If this is a fragment, don't bother parsing the transport header. + parseTransportHeader := true + if ipHdr.More() || ipHdr.FragmentOffset() != 0 { + parseTransportHeader = false + } + + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return ipHdr.TransportProtocol(), parseTransportHeader, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { @@ -467,5 +585,10 @@ func NewProtocol() stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} + return &protocol{ + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 99f84acd7..197e3bc51 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,9 +17,11 @@ package ipv4_test import ( "bytes" "encoding/hex" + "fmt" "math/rand" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -90,15 +92,11 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much +// makeRandPkt generates a randomize packet. hdrLength indicates how much // data should already be in the header before WritePacket. extraLength // indicates how much extra space should be in the header. The payload is made // from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - +func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { var views []buffer.View totalLength := 0 for _, s := range viewSizes { @@ -107,18 +105,26 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. views = append(views, newView) totalLength += s } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLength + extraLength, + Data: buffer.NewVectorisedView(totalLength, views), + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt } // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) + vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) + source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -131,8 +137,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe var reassembledPayload []byte for i, packet := range packets { // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -143,12 +148,22 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + if i == 0 { + got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() + // sourcePacketInfo does not have NetworkHeader added, simulate one. + want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() + // Check that it kept the transport header in packet.TransportHeader if + // it fits in the first fragment. + if want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } + if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { + t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { @@ -173,7 +188,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan *stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -183,17 +198,11 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan packetInfo, size), + Ch: make(chan *stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView -} - // Drain removes all outbound packets from the channel and counts them. func (e *errorChannel) Drain() int { c := 0 @@ -208,14 +217,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -291,19 +295,19 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), - } + pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + source := pkt.Clone() c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, pkt) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []*stack.PacketBuffer L: for { select { @@ -343,9 +347,13 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) + pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, pkt) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -451,7 +459,7 @@ func TestInvalidFragments(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - const nicid tcpip.NICID = 42 + const nicID tcpip.NICID = 42 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ ipv4.NewProtocol(), @@ -461,10 +469,12 @@ func TestInvalidFragments(t *testing.T) { var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicid, sniffer.New(ep)) + s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt})) + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + })) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { @@ -476,3 +486,423 @@ func TestInvalidFragments(t *testing.T) { }) } } + +// TestReceiveFragments feeds fragments in through the incoming packet path to +// test reassembly +func TestReceiveFragments(t *testing.T) { + const ( + nicID = 1 + + addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 + addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 + addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 + ) + + // Build and return a UDP header containing payload. + udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { + payload := buffer.NewView(payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) * multiplier + } + + udpLength := header.UDPMinimumSize + len(payload) + + hdr := buffer.NewPrependable(udpLength) + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), payload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) + sum = header.Checksum(payload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + return hdr.View() + } + + // UDP header plus a payload of 0..256 + ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) + udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] + ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) + udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] + // UDP header plus a payload of 0..256 in increments of 2. + ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) + udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] + // UDP header plus a payload of 0..256 in increments of 3. + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 791 section 3.1 page 14). + ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) + udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] + + type fragmentData struct { + srcAddr tcpip.Address + dstAddr tcpip.Address + id uint16 + flags uint8 + fragmentOffset uint16 + payload buffer.View + } + + tests := []struct { + name string + fragments []fragmentData + expectedPayloads [][]byte + }{ + { + name: "No fragmentation", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2, + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "No fragmentation with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2, + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "More fragments without payload", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2, + }, + }, + expectedPayloads: nil, + }, + { + name: "Non-zero fragment offset without payload", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 8, + payload: ipv4Payload1Addr1ToAddr2, + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments out of order", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload3Addr1ToAddr2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2[:63], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 63, + payload: ipv4Payload3Addr1ToAddr2[63:], + }, + }, + expectedPayloads: nil, + }, + { + name: "Second fragment has MoreFlags set", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with different IDs", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 2, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + }, + expectedPayloads: nil, + }, + { + name: "Two interleaved fragmented packets", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 2, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload2Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 2, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload2Addr1ToAddr2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, + }, + { + name: "Two interleaved fragmented packets from different sources but with same ID", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr3, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr3ToAddr2[:32], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + { + srcAddr: addr3, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 32, + payload: ipv4Payload1Addr3ToAddr2[32:], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, + }, + { + name: "Fragment without followup", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + }, + expectedPayloads: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Setup a stack and endpoint. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + // Prepare and send the fragments. + for _, frag := range test.fragments { + hdr := buffer.NewPrependable(header.IPv4MinimumSize) + + // Serialize IPv4 fixed header. + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), + ID: frag.id, + Flags: frag.flags, + FragmentOffset: frag.fragmentOffset, + TTL: 64, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: frag.srcAddr, + DstAddr: frag.dstAddr, + }) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(frag.payload) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { + t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) + } + + for i, expectedPayload := range test.expectedPayloads { + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("(i=%d) Read(nil): %s", i, err) + } + if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" { + t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + } + } + + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f06622a8b..bcc64994e 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,14 +8,12 @@ go_library( "icmp.go", "ipv6.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/stack", ], ) @@ -29,10 +26,11 @@ go_test( "ipv6_test.go", "ndp_test.go", ], - embed = [":ipv6"], + library = ":ipv6", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", @@ -40,5 +38,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index c3f1dd488..66d3a953a 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,6 +15,8 @@ package ipv6 import ( + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -25,26 +27,35 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it // is truncated, which would cause IsValid to return false. // // Drop packet if it doesn't have the basic IPv6 header or if the - // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { return } // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + pkt.Data.TrimFront(header.IPv6MinimumSize) + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -52,145 +63,183 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() - if len(v) < header.ICMPv6MinimumSize { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } h := header.ICMPv6(v) - iph := header.IPv6(netHeader) + iph := header.IPv6(pkt.NetworkHeader().View()) // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := vv - payload.RemoveFirst() + payload := pkt.Data.Clone(nil) + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return } - // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and - // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255. - switch h.Type() { - case header.ICMPv6NeighborSolicit, - header.ICMPv6NeighborAdvert, - header.ICMPv6RouterSolicit, - header.ICMPv6RouterAdvert, - header.ICMPv6RedirectMsg: - if iph.HopLimit() != header.NDPHopLimit { - received.Invalid.Increment() - return - } + isNDPValid := func() bool { + // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and + // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. + // + // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the + // packet includes a fragmentation header. + return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0 } // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + mtu := header.ICMPv6(hdr).MTU() + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + switch header.ICMPv6(hdr).Code() { + case header.ICMPv6NetworkUnreachable: + e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + received.Invalid.Increment() + return + } + + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) + it, err := ns.Options().Iter(true) + if err != nil { + // If we have a malformed NDP NS option, drop the packet. received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) targetAddr := ns.TargetAddress() s := r.Stack() - rxNICID := r.NICID() - - isTentative, err := s.IsAddrTentative(rxNICID, targetAddr) - if err != nil { - // We will only get an error if rxNICID is unrecognized, - // which should not happen. For now short-circuit this - // packet. + if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { + // We will only get an error if the NIC is unrecognized, which should not + // happen. For now, drop this packet. // // TODO(b/141002840): Handle this better? return - } - - if isTentative { - // If the target address is tentative and the source - // of the packet is a unicast (specified) address, then - // the source of the packet is attempting to perform - // address resolution on the target. In this case, the - // solicitation is silently ignored, as per RFC 4862 - // section 5.4.3. + } else if isTentative { + // If the target address is tentative and the source of the packet is a + // unicast (specified) address, then the source of the packet is + // attempting to perform address resolution on the target. In this case, + // the solicitation is silently ignored, as per RFC 4862 section 5.4.3. // - // If the target address is tentative and the source of - // the packet is the unspecified address (::), then we - // know another node is also performing DAD for the - // same address (since targetAddr is tentative for us, - // we know we are also performing DAD on it). In this - // case we let the stack know so it can handle such a - // scenario and do nothing further with the NDP NS. - if iph.SourceAddress() == header.IPv6Any { - s.DupTentativeAddrDetected(rxNICID, targetAddr) + // If the target address is tentative and the source of the packet is the + // unspecified address (::), then we know another node is also performing + // DAD for the same address (since the target address is tentative for us, + // we know we are also performing DAD on it). In this case we let the + // stack know so it can handle such a scenario and do nothing further with + // the NS. + if r.RemoteAddress == header.IPv6Any { + s.DupTentativeAddrDetected(e.nicID, targetAddr) } - // Do not handle neighbor solicitations targeted - // to an address that is tentative on the received - // NIC any further. + // Do not handle neighbor solicitations targeted to an address that is + // tentative on the NIC any further. return } - // At this point we know that targetAddr is not tentative on - // rxNICID so the packet is processed as defined in RFC 4861, - // as per RFC 4862 section 5.4.3. + // At this point we know that the target address is not tentative on the NIC + // so the packet is processed as defined in RFC 4861, as per RFC 4862 + // section 5.4.3. - if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { - // We don't have a useful answer; the best we can do is ignore the request. + // Is the NS targetting us? + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { return } - optsSerializer := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), + // If the NS message contains the Source Link-Layer Address option, update + // the link address cache with the value of the option. + // + // TODO(b/148429853): Properly process the NS message and do Neighbor + // Unreachability Detection. + var sourceLinkAddr tcpip.LinkAddress + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) + } + if done { + break + } + + switch opt := opt.(type) { + case header.NDPSourceLinkLayerAddressOption: + // No RFCs define what to do when an NS message has multiple Source + // Link-Layer Address options. Since no interface can have multiple + // link-layer addresses, we consider such messages invalid. + if len(sourceLinkAddr) != 0 { + received.Invalid.Increment() + return + } + + sourceLinkAddr = opt.EthernetAddress() + } + } + + unspecifiedSource := r.RemoteAddress == header.IPv6Any + + // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST + // NOT be included when the source IP address is the unspecified address. + // Otherwise, on link layers that have addresses this option MUST be + // included in multicast solicitations and SHOULD be included in unicast + // solicitations. + if len(sourceLinkAddr) == 0 { + if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + received.Invalid.Increment() + return + } + } else if unspecifiedSource { + received.Invalid.Increment() + return + } else { + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(targetAddr) - opts := na.Options() - opts.Serialize(optsSerializer) // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the @@ -203,16 +252,43 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - // TODO(tamird/ghanan): there exists an explicit NDP option that is - // used to update the neighbor table with link addresses for a - // neighbor from an NS (see the Source Link Layer option RFC - // 4861 section 4.6.1 and section 7.2.3). + // As per RFC 4861 section 7.2.4, if the the source of the solicitation is + // the unspecified address, the node MUST set the Solicited flag to zero and + // multicast the advertisement to the all-nodes address. + solicited := true + if unspecifiedSource { + solicited = false + r.RemoteAddress = header.IPv6AllNodesMulticastAddress + } + + // If the NS has a source link-layer option, use the link address it + // specifies as the remote link address for the response instead of the + // source link address of the packet. // - // Furthermore, the entirety of NDP handling here seems to be - // contradicted by RFC 4861. - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link + // address cache for the right destination link address instead of manually + // patching the route with the remote link address if one is specified in a + // Source Link-Layer Address option. + if len(sourceLinkAddr) != 0 { + r.RemoteLinkAddress = sourceLinkAddr + } + + optsSerializer := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), + } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), + }) + packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) + na.SetSolicitedFlag(solicited) + na.SetOverrideFlag(true) + na.SetTargetAddress(targetAddr) + opts := na.Options() + opts.Serialize(optsSerializer) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // @@ -220,7 +296,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.Dropped.Increment() return } @@ -228,64 +304,121 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + received.Invalid.Increment() + return + } + + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) + it, err := na.Options().Iter(true) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) targetAddr := na.TargetAddress() stack := r.Stack() - rxNICID := r.NICID() - isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr) - if err != nil { - // We will only get an error if rxNICID is unrecognized, - // which should not happen. For now short-circuit this - // packet. + if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil { + // We will only get an error if the NIC is unrecognized, which should not + // happen. For now short-circuit this packet. // // TODO(b/141002840): Handle this better? return - } - - if isTentative { - // We just got an NA from a node that owns an address we - // are performing DAD on, implying the address is not - // unique. In this case we let the stack know so it can - // handle such a scenario and do nothing furthur with + } else if isTentative { + // We just got an NA from a node that owns an address we are performing + // DAD on, implying the address is not unique. In this case we let the + // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - stack.DupTentativeAddrDetected(rxNICID, targetAddr) + stack.DupTentativeAddrDetected(e.nicID, targetAddr) return } - // At this point we know that the targetAddress is not tentative - // on rxNICID. However, targetAddr may still be assigned to - // rxNICID but not tentative (it could be permanent). Such a - // scenario is beyond the scope of RFC 4862. As such, we simply - // ignore such a scenario for now and proceed as normal. + // At this point we know that the target address is not tentative on the + // NIC. However, the target address may still be assigned to the NIC but not + // tentative (it could be permanent). Such a scenario is beyond the scope of + // RFC 4862. As such, we simply ignore such a scenario for now and proceed + // as normal. + // + // TODO(b/143147598): Handle the scenario described above. Also inform the + // netstack integration that a duplicate address was detected outside of + // DAD. + + // If the NA message has the target link layer option, update the link + // address cache with the link address for the target of the message. // - // TODO(b/143147598): Handle the scenario described above. Also - // inform the netstack integration that a duplicate address was - // detected outside of DAD. + // TODO(b/148429853): Properly process the NA message and do Neighbor + // Unreachability Detection. + var targetLinkAddr tcpip.LinkAddress + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) + } + if done { + break + } + + switch opt := opt.(type) { + case header.NDPTargetLinkLayerAddressOption: + // No RFCs define what to do when an NA message has multiple Target + // Link-Layer Address options. Since no interface can have multiple + // link-layer addresses, we consider such messages invalid. + if len(targetLinkAddr) != 0 { + received.Invalid.Increment() + return + } + + targetLinkAddr = opt.EthernetAddress() + } + } - e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) - if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + if len(targetLinkAddr) != 0 { + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) } case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + + remoteLinkAddr := r.RemoteLinkAddress + + // As per RFC 4291 section 2.7, multicast addresses must not be used as + // source addresses in IPv6 packets. + localAddr := r.LocalAddress + if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // Use the link address from the source of the original packet. + r.ResolveWith(remoteLinkAddr) + + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, + Data: pkt.Data, + }) + packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) + copy(packet, icmpHdr) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -293,11 +426,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -307,12 +440,64 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv6RouterSolicit: received.RouterSolicit.Increment() + if !isNDPValid() { + received.Invalid.Increment() + return + } case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + received.Invalid.Increment() + return + } + + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() + if !isNDPValid() { + received.Invalid.Increment() + return + } default: received.Invalid.Increment() @@ -331,8 +516,6 @@ const ( icmpV6LengthOffset = 25 ) -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - var _ stack.LinkAddressResolver = (*protocol)(nil) // LinkAddressProtocol implements stack.LinkAddressResolver. @@ -341,24 +524,34 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { snaddr := header.SolicitedNodeAddr(addr) + + // TODO(b/148672031): Use stack.FindRoute instead of manually creating the + // route here. Note, we would need the nicID to do this properly so the right + // NIC (associated to linkEP) is used to send the NDP NS message. r := &stack.Route{ LocalAddress: localAddr, RemoteAddress: snaddr, - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - copy(pkt[icmpV6OptOffset-len(addr):], addr) - pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - length := uint16(hdr.UsedLength()) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, + }) + icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + icmpHdr.SetType(header.ICMPv6NeighborSolicit) + copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) + icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr + icmpHdr[icmpV6LengthOffset] = 1 + copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -368,29 +561,13 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if header.IsV6MulticastAddress(addr) { - // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks - // - // 7. Address Mapping -- Multicast - // - // An IPv6 packet with a multicast destination address DST, - // consisting of the sixteen octets DST[1] through DST[16], is - // transmitted to the Ethernet multicast address whose first - // two octets are the value 3333 hexadecimal and whose last - // four octets are the last four octets of DST. - return tcpip.LinkAddress([]byte{ - 0x33, - 0x33, - addr[header.IPv6AddressSize-4], - addr[header.IPv6AddressSize-3], - addr[header.IPv6AddressSize-2], - addr[header.IPv6AddressSize-1], - }), true + return header.EthernetAddressFromMulticastIPv6Address(addr), true } - return "", false + return tcpip.LinkAddress([]byte(nil)), false } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index b112303b6..9e4eeea77 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "context" "reflect" "strings" "testing" @@ -31,7 +32,11 @@ import ( const ( linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + + defaultChannelSize = 1 + defaultMTU = 65536 ) var ( @@ -55,7 +60,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return nil } @@ -65,7 +70,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -109,10 +114,8 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } + ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) + defer ep.Close() r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -120,48 +123,90 @@ func TestICMPCounts(t *testing.T) { } defer r.Release() + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { - typ header.ICMPv6Type - size int + typ header.ICMPv6Type + size int + extraData []byte }{ - {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize}, - {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize}, - {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize}, - {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize}, - {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, - {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, - {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, - {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, - } - - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + { + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + }, + { + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + }, + { + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + }, + { + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, + { + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + }, + { + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + }, + } + + handleIPv6Payload := func(icmp header.ICMPv6) { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, pkt) } for _, typ := range types { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - handleIPv6Payload(hdr) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -214,8 +259,7 @@ func newTestContext(t *testing.T) *testContext { }), } - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { @@ -228,7 +272,7 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress lladdr0: %v", err) } - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -262,32 +306,40 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) + c.linkEP0.Close() + c.linkEP1.Close() } type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type + src, dst *channel.Endpoint + typ header.ICMPv6Type + remoteLinkAddr tcpip.LinkAddress } func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pkt := <-args.src.C + pi, _ := args.src.ReadContext(context.Background()) { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), + }) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) } - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) return } - ipv6 := header.IPv6(pkt.Header) + + if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { + t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + } + + // Pull the full payload since network header. Needed for header.IPv6 to + // extract its payload. + ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -334,7 +386,7 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) } for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit}, + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, } { routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { @@ -361,97 +413,104 @@ func TestLinkResolution(t *testing.T) { } func TestICMPChecksumValidationSimple(t *testing.T) { + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { name string typ header.ICMPv6Type size int + extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "DstUnreachable", + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.DstUnreachable }, }, { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "PacketTooBig", + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.PacketTooBig }, }, { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "TimeExceeded", + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.TimeExceeded }, }, { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "ParamProblem", + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.ParamProblem }, }, { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "EchoRequest", + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.EchoRequest }, }, { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "EchoReply", + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.EchoReply }, }, { - "RouterSolicit", - header.ICMPv6RouterSolicit, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }, }, { - "RouterAdvert", - header.ICMPv6RouterAdvert, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }, }, { - "NeighborSolicit", - header.ICMPv6NeighborSolicit, - header.ICMPv6NeighborSolicitMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.NeighborSolicit }, }, { - "NeighborAdvert", - header.ICMPv6NeighborAdvert, - header.ICMPv6NeighborAdvertSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.NeighborAdvert }, }, { - "RedirectMsg", - header.ICMPv6RedirectMsg, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RedirectMsg }, }, @@ -483,22 +542,25 @@ func TestICMPChecksumValidationSimple(t *testing.T) { ) } - handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + handleIPv6Payload := func(checksum bool) { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), + }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -515,7 +577,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { // Without setting checksum, the incoming packet should // be invalid. - handleIPv6Payload(typ.typ, typ.size, false) + handleIPv6Payload(false) if got := invalid.Value(); got != 1 { t.Fatalf("got invalid = %d, want = 1", got) } @@ -525,7 +587,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, true) + handleIPv6Payload(true) if got := typStat.Value(); got != 1 { t.Fatalf("got %s = %d, want = 1", typ.name, got) } @@ -657,12 +719,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { icmpSize := size + payloadSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) + icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) + icmpHdr.SetType(typ) + payloadFn(icmpHdr.Payload()) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -673,7 +735,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -831,14 +896,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + icmpHdr := header.ICMPv6(hdr.Prepend(size)) + icmpHdr.SetType(typ) payload := buffer.NewView(payloadSize) payloadFn(payload) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -849,9 +914,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, - buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, - []buffer.View{hdr.View(), payload})) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -889,3 +955,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { }) } } + +func TestLinkAddressRequest(t *testing.T) { + snaddr := header.SolicitedNodeAddr(lladdr0) + mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) + + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: linkAddr1, + expectLinkAddr: linkAddr1, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: mcaddr, + }, + } + + for _, test := range tests { + p := NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 5898f8f9e..0eafe9790 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -21,11 +21,13 @@ package ipv6 import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -43,13 +45,12 @@ const ( ) type endpoint struct { - nicid tcpip.NICID - id stack.NetworkEndpointID - prefixLen int + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher protocol *protocol + stack *stack.Stack } // DefaultTTL is the default hop limit for this endpoint. @@ -65,17 +66,7 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid -} - -// ID returns the ipv6 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id -} - -// PrefixLen returns the ipv6 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen + return e.nicID } // Capabilities implements stack.NetworkEndpoint.Capabilities. @@ -97,9 +88,9 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) { - length := uint16(hdr.UsedLength() + payloadSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(params.Protocol), @@ -108,86 +99,336 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - e.addIPHeader(r, &hdr, payload.Size(), params) - - if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { + e.addIPHeader(r, pkt, params) + + if r.Loop&stack.PacketLoop != 0 { loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // The inbound path expects an unparsed packet. + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } - if loop&stack.PacketOut == 0 { - return len(hdrs), nil + if r.Loop&stack.PacketOut == 0 { + return pkts.Len(), nil } - for i := range hdrs { - hdr := &hdrs[i].Hdr - size := hdrs[i].Size - e.addIPHeader(r, hdr, size, params) + for pb := pkts.Front(); pb != nil; pb = pb.Next() { + e.addIPHeader(r, pb, params) } - n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber) + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { - // TODO(b/119580726): Support IPv6 header-included packets. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { + // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() - h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.IPv6(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) - - p := h.TransportProtocol() - if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) - return + // vv consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader().View()) + vv.Append(pkt.Data) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) + hasFragmentHeader := false + + for firstHeader := true; ; firstHeader = false { + extHdr, done, err := it.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6HopByHopOptionsExtHdr: + // As per RFC 8200 section 4.1, the Hop By Hop extension header is + // restricted to appear immediately after an IPv6 fixed header. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 + // (unrecognized next header) error in response to an extension header's + // Next Header field with the Hop By Hop extension header identifier. + if !firstHeader { + return + } + + optsIt := extHdr.Iter() + + for { + opt, done, err := optsIt.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + // We currently do not support any IPv6 Hop By Hop extension header + // options. + switch opt.UnknownAction() { + case header.IPv6OptionUnknownActionSkip: + case header.IPv6OptionUnknownActionDiscard: + return + case header.IPv6OptionUnknownActionDiscardSendICMP: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + default: + panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) + } + } + + case header.IPv6RoutingExtHdr: + // As per RFC 8200 section 4.4, if a node encounters a routing header with + // an unrecognized routing type value, with a non-zero Segments Left + // value, the node must discard the packet and send an ICMP Parameter + // Problem, Code 0. If the Segments Left is 0, the node must ignore the + // Routing extension header and process the next header in the packet. + // + // Note, the stack does not yet handle any type of routing extension + // header, so we just make sure Segments Left is zero before processing + // the next extension header. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for + // unrecognized routing types with a non-zero Segments Left value. + if extHdr.SegmentsLeft() != 0 { + return + } + + case header.IPv6FragmentExtHdr: + hasFragmentHeader = true + + if extHdr.IsAtomic() { + // This fragment extension header indicates that this packet is an + // atomic fragment. An atomic fragment is a fragment that contains + // all the data required to reassemble a full packet. As per RFC 6946, + // atomic fragments must not interfere with "normal" fragmented traffic + // so we skip processing the fragment instead of feeding it through the + // reassembly process below. + continue + } + + // Don't consume the iterator if we have the first fragment because we + // will use it to validate that the first fragment holds the upper layer + // header. + rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */) + + if extHdr.FragmentOffset() == 0 { + // Check that the iterator ends with a raw payload as the first fragment + // should include all headers up to and including any upper layer + // headers, as per RFC 8200 section 4.5; only upper layer data + // (non-headers) should follow the fragment extension header. + var lastHdr header.IPv6PayloadHeader + + for { + it, done, err := it.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + lastHdr = it + } + + // If the last header is a raw header, then the last portion of the IPv6 + // payload is not a known IPv6 extension header. Note, this does not + // mean that the last portion is an upper layer header or not an + // extension header because: + // 1) we do not yet support all extension headers + // 2) we do not validate the upper layer header before reassembling. + // + // This check makes sure that a known IPv6 extension header is not + // present after the Fragment extension header in a non-initial + // fragment. + // + // TODO(#2196): Support IPv6 Authentication and Encapsulated + // Security Payload extension headers. + // TODO(#2333): Validate that the upper layer header is valid. + switch lastHdr.(type) { + case header.IPv6RawPayloadHeader: + default: + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + } + + fragmentPayloadLen := rawPayload.Buf.Size() + if fragmentPayloadLen == 0 { + // Drop the packet as it's marked as a fragment but has no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + // The packet is a fragment, let's try to reassemble it. + start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + last := start + uint16(fragmentPayloadLen) - 1 + + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. + if last < start { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + var ready bool + // Note that pkt doesn't have its transport header set after reassembly, + // and won't until DeliverNetworkPacket sets it. + pkt.Data, ready, err = e.protocol.fragmentation.Process( + // IPv6 ignores the Protocol field since the ID only needs to be unique + // across source-destination pairs, as per RFC 8200 section 4.5. + fragmentation.FragmentID{ + Source: h.SourceAddress(), + Destination: h.DestinationAddress(), + ID: extHdr.ID(), + }, + start, + last, + extHdr.More(), + rawPayload.Buf, + ) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + if ready { + // We create a new iterator with the reassembled packet because we could + // have more extension headers in the reassembled payload, as per RFC + // 8200 section 4.5. + it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data) + } + + case header.IPv6DestinationOptionsExtHdr: + optsIt := extHdr.Iter() + + for { + opt, done, err := optsIt.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + // We currently do not support any IPv6 Destination extension header + // options. + switch opt.UnknownAction() { + case header.IPv6OptionUnknownActionSkip: + case header.IPv6OptionUnknownActionDiscard: + return + case header.IPv6OptionUnknownActionDiscardSendICMP: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + default: + panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) + } + } + + case header.IPv6RawPayloadHeader: + // If the last header in the payload isn't a known IPv6 extension header, + // handle it as if it is transport layer data. + + // For unfragmented packets, extHdr still contains the transport header. + // Get rid of it. + // + // For reassembled fragments, pkt.TransportHeader is unset, so this is a + // no-op and pkt.Data begins with the transport header. + extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) + pkt.Data = extHdr.Buf + + if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { + e.handleICMP(r, pkt, hasFragmentHeader) + } else { + r.Stats().IP.PacketsDelivered.Increment() + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error + // in response to unrecognized next header values. + e.dispatcher.DeliverTransportPacket(r, p, pkt) + } + + default: + // If we receive a packet for an extension header we do not yet handle, + // drop the packet for now. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error + // in response to unrecognized next header values. + r.Stats().UnknownProtocolRcvdPackets.Increment() + return + } } - - r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) } // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + type protocol struct { // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful and it must be accessed // atomically. - defaultTTL uint32 + defaultTTL uint32 + fragmentation *fragmentation.Fragmentation } // Number returns the ipv6 protocol number. @@ -212,16 +453,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ - nicid: nicid, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, + nicID: nicID, linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, protocol: p, - }, nil + stack: st, + } } // SetOption implements NetworkProtocol.SetOption. @@ -256,6 +496,83 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + // + // Parsing occurs again in HandlePacket because we don't track the + // extensions in PacketBuffer. Unfortunately, that means HandlePacket + // has to do the parsing work again. + var nextHdr tcpip.TransportProtocolNumber + foundNext := true + extensionsSize := 0 +traverseExtensions: + for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { + if err != nil { + break + } + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + foundNext = false + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + // If this is an atomic fragment, we don't have to treat it specially. + if !extHdr.More() && extHdr.FragmentOffset() == 0 { + continue + } + // This is a non-atomic fragment and has to be re-assembled before we can + // examine the payload for a transport header. + foundNext = false + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + return nextHdr, foundNext, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { @@ -268,5 +585,8 @@ func calculateMTU(mtu uint32) uint32 { // NewProtocol returns an IPv6 network protocol. func NewProtocol() stack.NetworkProtocol { - return &protocol{defaultTTL: DefaultTTL} + return &protocol{ + defaultTTL: DefaultTTL, + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + } } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index deaa9b7f3..0a183bfde 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -17,6 +17,7 @@ package ipv6 import ( "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -33,6 +34,15 @@ const ( // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" + addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" + + // Tests use the extension header identifier values as uint8 instead of + // header.IPv6ExtensionHeaderIdentifier. + hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier) + routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier) + fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) + destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) + noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) ) // testReceiveICMP tests receiving an ICMP packet from src to dst. want is the @@ -55,7 +65,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) stats := s.Stats().ICMP.V6PacketsReceived @@ -111,7 +123,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) stat := s.Stats().UDP.PacketsReceived @@ -154,6 +168,8 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // packets destined to the IPv6 solicited-node address of an assigned IPv6 // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + const nicID = 1 + tests := []struct { name string protocolFactory stack.TransportProtocol @@ -171,50 +187,61 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as we haven't added - // those addresses. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as we haven't added those addresses. test.rxf(t, s, e, addr1, snmc, 0) - if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } - // Should receive a packet destined to the solicited - // node address of addr2/addr3 now that we have added - // added addr2. + // Should receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added added addr2. test.rxf(t, s, e, addr1, snmc, 1) - if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) } - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have added addr3. + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added addr3. test.rxf(t, s, e, addr1, snmc, 2) - if err := s.RemoveAddress(1, addr2); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + if err := s.RemoveAddress(nicID, addr2); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err) } - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have removed addr2. + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have removed addr2. test.rxf(t, s, e, addr1, snmc, 3) - if err := s.RemoveAddress(1, addr3); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + // Make sure addr3's endpoint does not get removed from the NIC by + // incrementing its reference count with a route. + r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err) + } + defer r.Release() + + if err := s.RemoveAddress(nicID, addr3); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err) } - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as both of them got - // removed. + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as both of them got removed, even though a route using + // addr3 exists. test.rxf(t, s, e, addr1, snmc, 3) }) } @@ -264,3 +291,1244 @@ func TestAddIpv6Address(t *testing.T) { }) } } + +func TestReceiveIPv6ExtHdrs(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + extHdr func(nextHdr uint8) ([]byte, uint8) + shouldAccept bool + }{ + { + name: "None", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, + shouldAccept: true, + }, + { + name: "hopbyhop with unknown option skippable action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: true, + }, + { + name: "hopbyhop with unknown option discard action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + shouldAccept: true, + }, + { + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + shouldAccept: false, + }, + { + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + shouldAccept: true, + }, + { + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + shouldAccept: true, + }, + { + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + shouldAccept: false, + }, + { + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + shouldAccept: false, + }, + { + name: "destination with unknown option skippable action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: true, + }, + { + name: "destination with unknown option discard action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: false, + }, + { + name: "destination with unknown option discard and send icmp action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: false, + }, + { + name: "destination with unknown option discard and send icmp action unless multicast dest", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: false, + }, + { + name: "routing - atomic fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + nextHdr, 0, 0, 0, 1, 2, 3, 4, + }, routingExtHdrID + }, + shouldAccept: true, + }, + { + name: "atomic fragment - routing", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Fragment extension header. + routingExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Routing extension header. + nextHdr, 0, 1, 0, 2, 3, 4, 5, + }, fragmentExtHdrID + }, + shouldAccept: true, + }, + { + name: "hop by hop (with skippable unknown) - routing", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with skippable unknown option. + routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, + + // Routing extension header. + nextHdr, 0, 1, 0, 2, 3, 4, 5, + }, hopByHopExtHdrID + }, + shouldAccept: true, + }, + { + name: "routing - hop by hop (with skippable unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Hop By Hop extension header with skippable unknown option. + nextHdr, 0, 62, 4, 1, 2, 3, 4, + }, routingExtHdrID + }, + shouldAccept: false, + }, + { + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + shouldAccept: false, + }, + { + name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with skippable unknown option. + routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, + + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Destination extension header with skippable unknown option. + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: true, + }, + { + name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with discard action for unknown option. + routingExtHdrID, 0, 65, 4, 1, 2, 3, 4, + + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Destination extension header with skippable unknown option. + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with skippable unknown option. + routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, + + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Destination extension header with discard action for unknown + // option. + nextHdr, 0, 65, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8} + udpLength := header.UDPMinimumSize + len(udpPayload) + extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber)) + extHdrLen := len(extHdrBytes) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength) + + // Serialize UDP message. + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), udpPayload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum = header.Checksum(udpPayload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + // Copy extension header bytes between the UDP message and the IPv6 + // fixed header. + copy(hdr.Prepend(extHdrLen), extHdrBytes) + + // Serialize IPv6 fixed header. + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: ipv6NextHdr, + HopLimit: 255, + SrcAddr: addr1, + DstAddr: addr2, + }) + + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + + stats := s.Stats().UDP.PacketsReceived + + if !test.shouldAccept { + if got := stats.Value(); got != 0 { + t.Errorf("got UDP Rx Packets = %d, want = 0", got) + } + + return + } + + // Expect a UDP packet. + if got := stats.Value(); got != 1 { + t.Errorf("got UDP Rx Packets = %d, want = 1", got) + } + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + + // Should not have any more UDP packets. + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} + +// fragmentData holds the IPv6 payload for a fragmented IPv6 packet. +type fragmentData struct { + srcAddr tcpip.Address + dstAddr tcpip.Address + nextHdr uint8 + data buffer.VectorisedView +} + +func TestReceiveIPv6Fragments(t *testing.T) { + const ( + nicID = 1 + udpPayload1Length = 256 + udpPayload2Length = 128 + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 8200 section 4.5). + udpPayload3Length = 127 + fragmentExtHdrLen = 8 + // Note, not all routing extension headers will be 8 bytes but this test + // uses 8 byte routing extension headers for most sub tests. + routingExtHdrLen = 8 + ) + + udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View { + payloadLen := len(payload) + for i := 0; i < payloadLen; i++ { + payload[i] = uint8(i) * multiplier + } + + udpLength := header.UDPMinimumSize + payloadLen + + hdr := buffer.NewPrependable(udpLength) + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), payload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) + sum = header.Checksum(payload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + return hdr.View() + } + + var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte + udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:] + ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2) + + var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte + udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:] + ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2) + + var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte + udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:] + ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2) + + var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte + udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] + ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) + + tests := []struct { + name string + expectedPayload []byte + fragments []fragmentData + expectedPayloads [][]byte + }{ + { + name: "No fragmentation", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: uint8(header.UDPProtocolNumber), + data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Atomic fragment", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), + []buffer.View{ + // Fragment extension header. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + + ipv6Payload1Addr1ToAddr2, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Atomic fragment with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), + []buffer.View{ + // Fragment extension header. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + + ipv6Payload3Addr1ToAddr2, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "Two fragments", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments out of order", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[:63], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[63:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with different IDs", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with per-fragment routing header with zero segments left", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with per-fragment routing header with non-zero segments left", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with zero segments left", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with routing header with non-zero segments left", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with zero segments left across fragments", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is fragmentExtHdrLen+8 because the + // first 8 bytes of the 16 byte routing extension header is in + // this fragment. + fragmentExtHdrLen+8, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header (part 1) + // + // Segments left = 0. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is + // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of + // the 16 byte routing extension header is in this fagment. + fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 1, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + + // Routing extension header (part 2) + buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + + ipv6Payload1Addr1ToAddr2, + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with non-zero segments left across fragments", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is fragmentExtHdrLen+8 because the + // first 8 bytes of the 16 byte routing extension header is in + // this fragment. + fragmentExtHdrLen+8, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header (part 1) + // + // Segments left = 1. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is + // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of + // the 16 byte routing extension header is in this fagment. + fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 1, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + + // Routing extension header (part 2) + buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + + ipv6Payload1Addr1ToAddr2, + }, + ), + }, + }, + expectedPayloads: nil, + }, + // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal" + // fragmented traffic. + { + name: "Two fragments with atomic", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + // This fragment has the same ID as the other fragments but is an atomic + // fragment. It should not interfere with the other fragments. + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), + + ipv6Payload2Addr1ToAddr2, + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2}, + }, + { + name: "Two interleaved fragmented packets", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), + + ipv6Payload2Addr1ToAddr2[:32], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 4, More = false, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), + + ipv6Payload2Addr1ToAddr2[32:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, + }, + { + name: "Two interleaved fragmented packets from different sources but with same ID", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr3, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr3ToAddr2[:32], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + { + srcAddr: addr3, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 4, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), + + ipv6Payload1Addr3ToAddr2[32:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + + // Serialize IPv6 fixed header. + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(f.data.Size()), + NextHeader: f.nextHdr, + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, + }) + + vv := hdr.View().ToVectorisedView() + vv.Append(f.data) + + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { + t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) + } + + for i, p := range test.expectedPayloads { + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("(i=%d) Read(nil): %s", i, err) + } + if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" { + t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + } + } + + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index c32716f2e..af71a7d6b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -20,7 +20,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) @@ -61,17 +63,476 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } + ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) return s, ep } -// TestHopLimitValidation is a test that makes sure that NDP packets are only -// received if their IP header's hop limit is set to 255. -func TestHopLimitValidation(t *testing.T) { +// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a +// valid NDP NS message with the Source Link Layer Address option results in a +// new entry in the link address cache for the sender of the message. +func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) + } + + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +func TestNeighorSolicitationResponse(t *testing.T) { + const nicID = 1 + nicAddr := lladdr0 + remoteAddr := lladdr1 + nicAddrSNMC := header.SolicitedNodeAddr(nicAddr) + nicLinkAddr := linkAddr0 + remoteLinkAddr0 := linkAddr1 + remoteLinkAddr1 := linkAddr2 + + tests := []struct { + name string + nsOpts header.NDPOptionsSerializer + nsSrcLinkAddr tcpip.LinkAddress + nsSrc tcpip.Address + nsDst tcpip.Address + nsInvalid bool + naDstLinkAddr tcpip.LinkAddress + naSolicited bool + naSrc tcpip.Address + naDst tcpip.Address + }{ + { + name: "Unspecified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Unspecified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: true, + }, + + { + name: "Specified source with 1 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Specified source with 2 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + + { + name: "Specified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 2 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(1, 1280, nicLinkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(nicAddr) + opts := ns.Options() + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) + } + + // If we expected the NS to be invalid, we have nothing else to check. + return + } + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + p, got := e.Read() + if !got { + t.Fatal("expected an NDP NA response") + } + + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) + }) + } +} + +// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a +// valid NDP NA message with the Target Link Layer Address option results in a +// new entry in the link address cache for the target of the message. +func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, + }, + { + name: "Multiple", + optsBuf: []byte{ + 2, 1, 2, 3, 4, 5, 6, 7, + 2, 1, 2, 3, 4, 5, 6, 8, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr1) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) + } + + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +func TestNDPValidation(t *testing.T) { setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { t.Helper() @@ -87,94 +548,357 @@ func TestHopLimitValidation(t *testing.T) { return s, ep, r } - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View + if atomicFragment { + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr + nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), + PayloadLength: uint16(len(payload) + len(extensions)), + NextHeader: nextHdr, HopLimit: hopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } + ep.HandlePacket(r, pkt) } + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { name string typ header.ICMPv6Type size int + extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ - {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }}, - {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }}, - {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }}, - {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }}, + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code header.ICMPv6Code + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, } for _, typ := range types { t.Run(typ.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep, r := setup(t) + defer r.Release() - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } - // Should not have received any ICMPv6 packets with - // type = typ.typ. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } + + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 + } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + }) } + }) + } +} + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code header.ICMPv6Code + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } - // Receive the NDP packet with an invalid hop limit - // value. - handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) } - // Rx count of NDP packet of type typ.typ should not - // have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) } - // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - // Rx count of NDP packet of type typ.typ should have - // increased. - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } } }) } diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 11efb4e44..2bad05a2e 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,14 +1,13 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "ports", srcs = ["ports.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", ], ) @@ -16,7 +15,7 @@ go_library( go_test( name = "ports_test", srcs = ["ports_test.go"], - embed = [":ports"], + library = ":ports", deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 30cea8996..f6d592eb5 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -18,9 +18,9 @@ package ports import ( "math" "math/rand" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -41,6 +41,46 @@ type portDescriptor struct { port uint16 } +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool +} + +// Bits converts the Flags to their bitset form. +func (f Flags) Bits() BitFlags { + var rf BitFlags + if f.MostRecent { + rf |= MostRecentFlag + } + if f.LoadBalanced { + rf |= LoadBalancedFlag + } + if f.TupleOnly { + rf |= TupleOnlyFlag + } + return rf +} + +// Effective returns the effective behavior of a flag config. +func (f Flags) Effective() Flags { + e := f + if e.LoadBalanced && e.MostRecent { + e.MostRecent = false + } + return e +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -54,9 +94,144 @@ type PortManager struct { hint uint32 } -type portNode struct { - reuse bool - refs int +// BitFlags is a bitset representation of Flags. +type BitFlags uint32 + +const ( + // MostRecentFlag represents Flags.MostRecent. + MostRecentFlag BitFlags = 1 << iota + + // LoadBalancedFlag represents Flags.LoadBalanced. + LoadBalancedFlag + + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + + // nextFlag is the value that the next added flag will have. + // + // It is used to calculate FlagMask below. It is also the number of + // valid flag states. + nextFlag + + // FlagMask is a bit mask for BitFlags. + FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag +) + +// ToFlags converts the bitset into a Flags struct. +func (f BitFlags) ToFlags() Flags { + return Flags{ + MostRecent: f&MostRecentFlag != 0, + LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, + } +} + +// FlagCounter counts how many references each flag combination has. +type FlagCounter struct { + // refs stores the count for each possible flag combination, (0 though + // FlagMask). + refs [nextFlag]int +} + +// AddRef increases the reference count for a specific flag combination. +func (c *FlagCounter) AddRef(flags BitFlags) { + c.refs[flags]++ +} + +// DropRef decreases the reference count for a specific flag combination. +func (c *FlagCounter) DropRef(flags BitFlags) { + c.refs[flags]-- +} + +// TotalRefs calculates the total number of references for all flag +// combinations. +func (c FlagCounter) TotalRefs() int { + var total int + for _, r := range c.refs { + total += r + } + return total +} + +// FlagRefs returns the number of references with all specified flags. +func (c FlagCounter) FlagRefs(flags BitFlags) int { + var total int + for i, r := range c.refs { + if BitFlags(i)&flags == flags { + total += r + } + } + return total +} + +// AllRefsHave returns if all references have all specified flags. +func (c FlagCounter) AllRefsHave(flags BitFlags) bool { + for i, r := range c.refs { + if BitFlags(i)&flags != flags && r > 0 { + return false + } + } + return true +} + +// IntersectionRefs returns the set of flags shared by all references. +func (c FlagCounter) IntersectionRefs() BitFlags { + intersection := FlagMask + for i, r := range c.refs { + if r > 0 { + intersection &= BitFlags(i) + } + } + return intersection +} + +type destination struct { + addr tcpip.Address + port uint16 +} + +func makeDestination(a tcpip.FullAddress) destination { + return destination{ + a.Addr, + a.Port, + } +} + +// portNode is never empty. When it has no elements, it is removed from the +// map that references it. +type portNode map[destination]FlagCounter + +// intersectionRefs calculates the intersection of flag bit values which affect +// the specified destination. +// +// If no destinations are present, all flag values are returned as there are no +// entries to limit possible flag values of a new entry. +// +// In addition to the intersection, the number of intersecting refs is +// returned. +func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { + intersection := FlagMask + var count int + + for d, f := range p { + if d == dst { + intersection &= f.IntersectionRefs() + count++ + continue + } + // Wildcard destinations affect all destinations for TupleOnly. + if d.addr == anyIPAddress || dst.addr == anyIPAddress { + // Only bitwise and the TupleOnlyFlag. + intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + count++ + } + } + + return intersection, count } // deviceNode is never empty. When it has no elements, it is removed from the @@ -64,32 +239,45 @@ type portNode struct { type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a -// device, check against all portNodes. If binding to a specific device, check +// device, check against all FlagCounters. If binding to a specific device, check // against the unspecified device and the provided device. -func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { +// +// If either of the port reuse flags is enabled on any of the nodes, all nodes +// sharing a port must share at least one reuse flag. This matches Linux's +// behavior. +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { + flagBits := flags.Bits() if bindToDevice == 0 { - // Trying to binding all devices. - if !reuse { - // Can't bind because the (addr,port) is already bound. - return false - } + intersection := FlagMask for _, p := range d { - if !p.reuse { - // Can't bind because the (addr,port) was previously bound without reuse. + i, c := p.intersectionRefs(dst) + if c == 0 { + continue + } + intersection &= i + if intersection&flagBits == 0 { + // Can't bind because the (addr,port) was + // previously bound without reuse. return false } } return true } + intersection := FlagMask + if p, ok := d[0]; ok { - if !reuse || !p.reuse { + var c int + intersection, c = p.intersectionRefs(dst) + if c > 0 && intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - if !reuse || !p.reuse { + i, c := p.intersectionRefs(dst) + intersection &= i + if c > 0 && intersection&flagBits == 0 { return false } } @@ -103,12 +291,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -117,14 +305,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -190,17 +378,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice, dst) { return false } } @@ -212,14 +400,16 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() + dst := makeDestination(dest) + // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { return 0, tcpip.ErrPortInUse } return port, nil @@ -227,16 +417,18 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { return false } + flagBits := flags.Bits() + // Reserve port on all network protocols. for _, network := range networks { desc := portDescriptor{network, transport, port} @@ -250,12 +442,65 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - if n, ok := d[bindToDevice]; ok { - n.refs++ - d[bindToDevice] = n - } else { - d[bindToDevice] = portNode{reuse: reuse, refs: 1} + p := d[bindToDevice] + if p == nil { + p = make(portNode) } + n := p[dst] + n.AddRef(flagBits) + p[dst] = n + d[bindToDevice] = p + } + + return true +} + +// ReserveTuple adds a port reservation for the tuple on all given protocol. +func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { + flagBits := flags.Bits() + dst := makeDestination(dest) + + s.mu.Lock() + defer s.mu.Unlock() + + // It is easier to undo the entire reservation, so if we find that the + // tuple can't be fully added, finish and undo the whole thing. + undo := false + + // Reserve port on all network protocols. + for _, network := range networks { + desc := portDescriptor{network, transport, port} + m, ok := s.allocatedPorts[desc] + if !ok { + m = make(bindAddresses) + s.allocatedPorts[desc] = m + } + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + + n := p[dst] + if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + // Tuple already exists. + undo = true + } + n.AddRef(flagBits) + p[dst] = n + d[bindToDevice] = p + } + + if undo { + // releasePortLocked decrements the counts (rather than setting + // them to zero), so it will undo the incorrect incrementing + // above. + s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + return false } return true @@ -263,10 +508,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { s.mu.Lock() defer s.mu.Unlock() + s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) +} + +func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -274,21 +523,32 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n, ok := d[bindToDevice] + p, ok := d[bindToDevice] if !ok { continue } - n.refs-- - d[bindToDevice] = n - if n.refs == 0 { - delete(d, bindToDevice) + n, ok := p[dst] + if !ok { + continue } - if len(d) == 0 { - delete(m, addr) + n.DropRef(flags) + if n.TotalRefs() > 0 { + p[dst] = n + continue + } + delete(p, dst) + if len(p) > 0 { + continue } - if len(m) == 0 { - delete(s.allocatedPorts, desc) + delete(d, bindToDevice) + if len(d) > 0 { + continue + } + delete(m, addr) + if len(m) > 0 { + continue } + delete(s.allocatedPorts, desc) } } } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 19f4833fc..58db5868c 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -33,9 +33,10 @@ type portReserveTestAction struct { port uint16 ip tcpip.Address want *tcpip.Error - reuse bool + flags Flags release bool device tcpip.NICID + dest tcpip.FullAddress } func TestPortReservation(t *testing.T) { @@ -50,7 +51,7 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, }, }, { @@ -61,7 +62,7 @@ func TestPortReservation(t *testing.T) { /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -71,36 +72,36 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 00, ip: fakeIPAddress, want: nil}, {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to ip with reuseport", actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: true, want: nil}, + {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to inaddr any with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, release: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, }, }, { tname: "bind twice with device fails", @@ -125,88 +126,200 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { - tname: "bind with reuse", + tname: "bind with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { - tname: "binding with reuse and device", + tname: "binding with reuseport and device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, }, }, { - tname: "mixing reuse and not reuse by binding to device", + tname: "mixing reuseport and not reuseport by binding to device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: nil}, }, }, { - tname: "can't bind to 0 after mixing reuse and not reuse", + tname: "can't bind to 0 after mixing reuseport and not reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, }, }, { - tname: "bind twice with reuse once", + tname: "bind twice with reuseport once", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "release an unreserved device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, // Release all. - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, + }, + }, { + tname: "bind with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind twice with reuseaddr once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseport, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind two tuples with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind two tuples", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind wildcard, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard twice with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, }, }, } { @@ -216,19 +329,18 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) } } }) - } } diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index a57752a7c..cf0a5fefe 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -1,10 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) go_binary( name = "tun_tcp_connect", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 2239c1e66..0ab089208 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -164,7 +164,7 @@ func main() { // Create TCP endpoint. var wq waiter.Queue ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { + if e != nil { log.Fatal(e) } diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index dad8ef399..43264b76d 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -1,10 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) go_binary( name = "tun_tcp_echo", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index bca73cbb1..9e37cab18 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -168,7 +168,7 @@ func main() { // Create TCP endpoint, bind it, then start listening. var wq waiter.Queue ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if err != nil { + if e != nil { log.Fatal(e) } diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index 29b7d761c..45f503845 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,12 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "seqnum", srcs = ["seqnum.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go index b40a3c212..d3bea7de4 100644 --- a/pkg/tcpip/seqnum/seqnum.go +++ b/pkg/tcpip/seqnum/seqnum.go @@ -46,11 +46,6 @@ func (v Value) InWindow(first Value, size Size) bool { return v.InRange(first, first.Add(size)) } -// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y). -func Overlap(a Value, b Size, x Value, y Size) bool { - return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b)) -} - // Add calculates the sequence number following the [v, v+s) window. func (v Value) Add(s Size) Value { return v + Value(s) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 460db3cf8..900938dd1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -16,35 +15,88 @@ go_template_instance( }, ) +go_template_instance( + name = "neighbor_entry_list", + out = "neighbor_entry_list.go", + package = "stack", + prefix = "neighborEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*neighborEntry", + "Linker": "*neighborEntry", + }, +) + +go_template_instance( + name = "packet_buffer_list", + out = "packet_buffer_list.go", + package = "stack", + prefix = "PacketBuffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*PacketBuffer", + "Linker": "*PacketBuffer", + }, +) + +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ + "conntrack.go", + "dhcpv6configurationfromndpra_string.go", + "forwarder.go", + "headertype_string.go", "icmp_rate_limit.go", + "iptables.go", + "iptables_state.go", + "iptables_targets.go", + "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", "ndp.go", + "neighbor_cache.go", + "neighbor_entry.go", + "neighbor_entry_list.go", + "neighborstate_string.go", "nic.go", + "nud.go", + "packet_buffer.go", + "packet_buffer_list.go", + "rand.go", "registration.go", "route.go", "stack.go", "stack_global_state.go", + "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/ilist", + "//pkg/log", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", + "//pkg/tcpip/transport/tcpconntrack", "//pkg/waiter", "@org_golang_x_time//rate:go_default_library", ], @@ -52,46 +104,57 @@ go_library( go_test( name = "stack_x_test", - size = "small", + size = "medium", srcs = [ "ndp_test.go", + "nud_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", ], + shard_count = 20, deps = [ ":stack", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) go_test( name = "stack_test", size = "small", - srcs = ["linkaddrcache_test.go"], - embed = [":stack"], + srcs = [ + "fake_time_test.go", + "forwarder_test.go", + "linkaddrcache_test.go", + "neighbor_cache_test.go", + "neighbor_entry_test.go", + "nic_test.go", + "packet_buffer_test.go", + ], + library = ":stack", deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "@com_github_dpjacques_clockwork//:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) - -filegroup( - name = "autogen", - srcs = [ - "linkaddrentry_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go new file mode 100644 index 000000000..7dd344b4f --- /dev/null +++ b/pkg/tcpip/stack/conntrack.go @@ -0,0 +1,631 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" +) + +// Connection tracking is used to track and manipulate packets for NAT rules. +// The connection is created for a packet if it does not exist. Every +// connection contains two tuples (original and reply). The tuples are +// manipulated if there is a matching NAT rule. The packet is modified by +// looking at the tuples in the Prerouting and Output hooks. +// +// Currently, only TCP tracking is supported. + +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 + +// Direction of the tuple. +type direction int + +const ( + dirOriginal direction = iota + dirReply +) + +// Manipulation type for the connection. +type manipType int + +const ( + manipNone manipType = iota + manipDstPrerouting + manipDstOutput +) + +// tuple holds a connection's identifying and manipulating data in one +// direction. It is immutable. +// +// +stateify savable +type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry + + tupleID + + // conn is the connection tracking entry this tuple belongs to. + conn *conn + + // direction is the direction of the tuple. + direction direction +} + +// tupleID uniquely identifies a connection in one direction. It currently +// contains enough information to distinguish between any TCP or UDP +// connection, and will need to be extended to support other protocols. +// +// +stateify savable +type tupleID struct { + srcAddr tcpip.Address + srcPort uint16 + dstAddr tcpip.Address + dstPort uint16 + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber +} + +// reply creates the reply tupleID. +func (ti tupleID) reply() tupleID { + return tupleID{ + srcAddr: ti.dstAddr, + srcPort: ti.dstPort, + dstAddr: ti.srcAddr, + dstPort: ti.srcPort, + transProto: ti.transProto, + netProto: ti.netProto, + } +} + +// conn is a tracked connection. +// +// +stateify savable +type conn struct { + // original is the tuple in original direction. It is immutable. + original tuple + + // reply is the tuple in reply direction. It is immutable. + reply tuple + + // manip indicates if the packet should be manipulated. It is immutable. + manip manipType + + // tcbHook indicates if the packet is inbound or outbound to + // update the state of tcb. It is immutable. + tcbHook Hook + + // mu protects all mutable state. + mu sync.Mutex `state:"nosave"` + // tcb is TCB control block. It is used to keep track of states + // of tcp connection and is protected by mu. + tcb tcpconntrack.TCB + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` +} + +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout +} + +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } +} + +// ConnTrack tracks all connections created for NAT rules. Most users are +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable +type ConnTrack struct { + // seed is a one-time random value initialized at stack startup + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 + + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket +} + +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList +} + +// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid +// TCP header. +func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { + // TODO(gvisor.dev/issue/170): Need to support for other + // protocols as well. + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { + return tupleID{}, tcpip.ErrUnknownProtocol + } + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { + return tupleID{}, tcpip.ErrUnknownProtocol + } + + return tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: tcpHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: tcpHeader.DestinationPort(), + transProto: netHeader.TransportProtocol(), + netProto: header.IPv4ProtocolNumber, + }, nil +} + +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn +} + +// connFor gets the conn for pkt if it exists, or returns nil +// if it does not. It returns an error when pkt does not contain a valid TCP +// header. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support +// other transport protocols. +func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil, dirOriginal + } + return ct.connForTID(tid) +} + +func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } + } + + return nil, dirOriginal +} + +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Prerouting && hook != Output { + return nil + } + + // Create a new connection and change the port as per the iptables + // rule. This tuple will be used to manipulate the packet in + // handlePacket. + replyTID := tid.reply() + replyTID.srcAddr = rt.MinIP + replyTID.srcPort = rt.MinPort + var manip manipType + switch hook { + case Prerouting: + manip = manipDstPrerouting + case Output: + manip = manipDstOutput + } + conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn +} + +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() + } + + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } + + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + } + + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } +} + +// handlePacketPrerouting manipulates ports for packets in Prerouting hook. +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. +func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) + + // For prerouting redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. + switch dir { + case dirOriginal: + port := conn.reply.srcPort + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.reply.srcAddr) + case dirReply: + port := conn.original.dstPort + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.original.dstAddr) + } + + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// handlePacketOutput manipulates ports for packets in Output hook. +func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) + + // For output redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. For prerouting redirection, we only reach this point + // when replying, so packet sources are modified. + if conn.manip == manipDstOutput && dir == dirOriginal { + port := conn.reply.srcPort + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.reply.srcAddr) + } else { + port := conn.original.dstPort + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.original.dstAddr) + } + + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) + xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + if gso != nil && gso.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// handlePacket will manipulate the port and address of the packet if the +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { + if pkt.NatDone { + return false + } + + if hook != Prerouting && hook != Output { + return false + } + + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + return false + } + + conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. + if conn == nil { + return true + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { + return false + } + + switch hook { + case Prerouting: + handlePacketPrerouting(pkt, conn, dir) + case Output: + handlePacketOutput(pkt, conn, gso, r, dir) + } + pkt.NatDone = true + + // Update the state of tcb. + // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle + // other tcp states. + conn.mu.Lock() + defer conn.mu.Unlock() + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return + } + + // We only track TCP connections. + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + return + } + + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { + return + } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + ct.insertConn(conn) +} + +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} + +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval + } + // We've hit the maximum interval. + return idx, maxInterval +} + +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false + } + + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } + + return true +} + +func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + // Lookup the connection. The reply's original destination + // describes the original address. + tid := tupleID{ + srcAddr: epID.LocalAddress, + srcPort: epID.LocalPort, + dstAddr: epID.RemoteAddress, + dstPort: epID.RemotePort, + transProto: header.TCPProtocolNumber, + netProto: header.IPv4ProtocolNumber, + } + conn, _ := ct.connForTID(tid) + if conn == nil { + // Not a tracked connection. + return "", 0, tcpip.ErrNotConnected + } else if conn.manip == manipNone { + // Unmanipulated connection. + return "", 0, tcpip.ErrInvalidOptionValue + } + + return conn.original.dstAddr, conn.original.dstPort, nil +} diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go new file mode 100644 index 000000000..d199ded6a --- /dev/null +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DHCPv6NoConfiguration-1] + _ = x[DHCPv6ManagedAddress-2] + _ = x[DHCPv6OtherConfigurations-3] +} + +const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" + +var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} + +func (i DHCPv6ConfigurationFromNDPRA) String() string { + i -= 1 + if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] +} diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go new file mode 100644 index 000000000..92c8cb534 --- /dev/null +++ b/pkg/tcpip/stack/fake_time_test.go @@ -0,0 +1,209 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "container/heap" + "sync" + "time" + + "github.com/dpjacques/clockwork" + "gvisor.dev/gvisor/pkg/tcpip" +) + +type fakeClock struct { + clock clockwork.FakeClock + + // mu protects the fields below. + mu sync.RWMutex + + // times is min-heap of times. A heap is used for quick retrieval of the next + // upcoming time of scheduled work. + times *timeHeap + + // waitGroups stores one WaitGroup for all work scheduled to execute at the + // same time via AfterFunc. This allows parallel execution of all functions + // passed to AfterFunc scheduled for the same time. + waitGroups map[time.Time]*sync.WaitGroup +} + +func newFakeClock() *fakeClock { + return &fakeClock{ + clock: clockwork.NewFakeClock(), + times: &timeHeap{}, + waitGroups: make(map[time.Time]*sync.WaitGroup), + } +} + +var _ tcpip.Clock = (*fakeClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (fc *fakeClock) NowNanoseconds() int64 { + return fc.clock.Now().UnixNano() +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (fc *fakeClock) NowMonotonic() int64 { + return fc.NowNanoseconds() +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := fc.clock.Now().Add(d) + wg := fc.addWait(until) + return &fakeTimer{ + clock: fc, + until: until, + timer: fc.clock.AfterFunc(d, func() { + defer wg.Done() + f() + }), + } +} + +// addWait adds an additional wait to the WaitGroup for parallel execution of +// all work scheduled for t. Returns a reference to the WaitGroup modified. +func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { + fc.mu.RLock() + wg, ok := fc.waitGroups[t] + fc.mu.RUnlock() + + if ok { + wg.Add(1) + return wg + } + + fc.mu.Lock() + heap.Push(fc.times, t) + fc.mu.Unlock() + + wg = &sync.WaitGroup{} + wg.Add(1) + + fc.mu.Lock() + fc.waitGroups[t] = wg + fc.mu.Unlock() + + return wg +} + +// removeWait removes a wait from the WaitGroup for parallel execution of all +// work scheduled for t. +func (fc *fakeClock) removeWait(t time.Time) { + fc.mu.RLock() + defer fc.mu.RUnlock() + + wg := fc.waitGroups[t] + wg.Done() +} + +// advance executes all work that have been scheduled to execute within d from +// the current fake time. Blocks until all work has completed execution. +func (fc *fakeClock) advance(d time.Duration) { + // Block until all the work is done + until := fc.clock.Now().Add(d) + for { + fc.mu.Lock() + if fc.times.Len() == 0 { + fc.mu.Unlock() + return + } + + t := heap.Pop(fc.times).(time.Time) + if t.After(until) { + // No work to do + heap.Push(fc.times, t) + fc.mu.Unlock() + return + } + fc.mu.Unlock() + + diff := t.Sub(fc.clock.Now()) + fc.clock.Advance(diff) + + fc.mu.RLock() + wg := fc.waitGroups[t] + fc.mu.RUnlock() + + wg.Wait() + + fc.mu.Lock() + delete(fc.waitGroups, t) + fc.mu.Unlock() + } +} + +type fakeTimer struct { + clock *fakeClock + timer clockwork.Timer + + mu sync.RWMutex + until time.Time +} + +var _ tcpip.Timer = (*fakeTimer)(nil) + +// Reset implements tcpip.Timer.Reset. +func (ft *fakeTimer) Reset(d time.Duration) { + if !ft.timer.Reset(d) { + return + } + + ft.mu.Lock() + defer ft.mu.Unlock() + + ft.clock.removeWait(ft.until) + ft.until = ft.clock.clock.Now().Add(d) + ft.clock.addWait(ft.until) +} + +// Stop implements tcpip.Timer.Stop. +func (ft *fakeTimer) Stop() bool { + if !ft.timer.Stop() { + return false + } + + ft.mu.RLock() + defer ft.mu.RUnlock() + + ft.clock.removeWait(ft.until) + return true +} + +type timeHeap []time.Time + +var _ heap.Interface = (*timeHeap)(nil) + +func (h timeHeap) Len() int { + return len(h) +} + +func (h timeHeap) Less(i, j int) bool { + return h[i].Before(h[j]) +} + +func (h timeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *timeHeap) Push(x interface{}) { + *h = append(*h, x.(time.Time)) +} + +func (h *timeHeap) Pop() interface{} { + last := (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] + return last +} diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go new file mode 100644 index 000000000..3eff141e6 --- /dev/null +++ b/pkg/tcpip/stack/forwarder.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // maxPendingResolutions is the maximum number of pending link-address + // resolutions. + maxPendingResolutions = 64 + maxPendingPacketsPerResolution = 256 +) + +type pendingPacket struct { + nic *NIC + route *Route + proto tcpip.NetworkProtocolNumber + pkt *PacketBuffer +} + +type forwardQueue struct { + sync.Mutex + + // The packets to send once the resolver completes. + packets map[<-chan struct{}][]*pendingPacket + + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + cancelChans []chan struct{} +} + +func newForwardQueue() *forwardQueue { + return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} +} + +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + shouldWait := false + + f.Lock() + packets, ok := f.packets[ch] + if !ok { + shouldWait = true + } + for len(packets) == maxPendingPacketsPerResolution { + p := packets[0] + packets = packets[1:] + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Release() + } + if l := len(packets); l >= maxPendingPacketsPerResolution { + panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) + } + f.packets[ch] = append(packets, &pendingPacket{ + nic: n, + route: r, + proto: protocol, + pkt: pkt, + }) + f.Unlock() + + if !shouldWait { + return + } + + // Wait for the link-address resolution to complete. + // Start a goroutine with a forwarding-cancel channel so that we can + // limit the maximum number of goroutines running concurrently. + cancel := f.newCancelChannel() + go func() { + cancelled := false + select { + case <-ch: + case <-cancel: + cancelled = true + } + + f.Lock() + packets := f.packets[ch] + delete(f.packets, ch) + f.Unlock() + + for _, p := range packets { + if cancelled { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else if _, err := p.route.Resolve(nil); err != nil { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else { + p.nic.forwardPacket(p.route, p.proto, p.pkt) + } + p.route.Release() + } + }() +} + +// newCancelChannel creates a channel that can cancel a pending forwarding +// activity. The oldest channel is closed if the number of open channels would +// exceed maxPendingResolutions. +func (f *forwardQueue) newCancelChannel() chan struct{} { + f.Lock() + defer f.Unlock() + + if len(f.cancelChans) == maxPendingResolutions { + ch := f.cancelChans[0] + f.cancelChans = f.cancelChans[1:] + close(ch) + } + if l := len(f.cancelChans); l >= maxPendingResolutions { + panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) + } + + ch := make(chan struct{}) + f.cancelChans = append(f.cancelChans, ch) + return ch +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go new file mode 100644 index 000000000..9dff23623 --- /dev/null +++ b/pkg/tcpip/stack/forwarder_test.go @@ -0,0 +1,648 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + fwdTestNetHeaderLen = 12 + fwdTestNetDefaultPrefixLen = 8 + + // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match + // the MTU of loopback interfaces on linux systems. + fwdTestNetDefaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 +) + +// fwdTestNetworkEndpoint is a network-layer protocol endpoint. +// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only +// use the first three: destination address, source address, and transport +// protocol. They're all one byte fields to simplify parsing. +type fwdTestNetworkEndpoint struct { + nicID tcpip.NICID + proto *fwdTestNetworkProtocol + dispatcher TransportDispatcher + ep LinkEndpoint +} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) +} + +func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { + return f.nicID +} + +func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { + return 123 +} + +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) +} + +func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { + return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { + return 0 +} + +func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { + return f.ep.Capabilities() +} + +func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return f.proto.Number() +} + +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { + // Add the protocol's header to the packet and send it to the link + // endpoint. + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) + b[dstAddrOffset] = r.RemoteAddress[0] + b[srcAddrOffset] = r.LocalAddress[0] + b[protocolNumberOffset] = byte(params.Protocol) + + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { + panic("not implemented") +} + +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func (*fwdTestNetworkEndpoint) Close() {} + +// fwdTestNetworkProtocol is a network-layer protocol that implements Address +// resolution. +type fwdTestNetworkProtocol struct { + addrCache *linkAddrCache + addrResolveDelay time.Duration + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) + onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) +} + +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) + +func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { + return fakeNetNumber +} + +func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { + return fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { + return fwdTestNetDefaultPrefixLen +} + +func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) +} + +func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) + if !ok { + return 0, false, false + } + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true +} + +func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { + return &fwdTestNetworkEndpoint{ + nicID: nicID, + proto: f, + dispatcher: dispatcher, + ep: ep, + } +} + +func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Close() {} + +func (f *fwdTestNetworkProtocol) Wait() {} + +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + if f.addrCache != nil && f.onLinkAddressResolved != nil { + time.AfterFunc(f.addrResolveDelay, func() { + f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr) + }) + } + return nil +} + +func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if f.onResolveStaticAddress != nil { + return f.onResolveStaticAddress(addr) + } + return "", false +} + +func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return fakeNetNumber +} + +// fwdTestPacketInfo holds all the information about an outbound packet. +type fwdTestPacketInfo struct { + RemoteLinkAddress tcpip.LinkAddress + LocalLinkAddress tcpip.LinkAddress + Pkt *PacketBuffer +} + +type fwdTestLinkEndpoint struct { + dispatcher NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + + // C is where outbound packets are queued. + C chan fwdTestPacketInfo +} + +// InjectInbound injects an inbound packet. +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) +} + +// InjectLinkAddr injects an inbound packet with a remote link address. +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *fwdTestLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *fwdTestLinkEndpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { + caps := LinkEndpointCapabilities(0) + return caps | CapabilityResolutionRequired +} + +// GSOMaxSize returns the maximum GSO packet size. +func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { + return 1 << 15 +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + p := fwdTestPacketInfo{ + RemoteLinkAddress: r.RemoteLinkAddress, + LocalLinkAddress: r.LocalLinkAddress, + Pkt: pkt, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.WritePacket(r, gso, protocol, pkt) + n++ + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + p := fwdTestPacketInfo{ + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*fwdTestLinkEndpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { + // Create a stack with the network protocol and two NICs. + s := New(Options{ + NetworkProtocols: []NetworkProtocol{proto}, + }) + + proto.addrCache = s.linkAddrCache + + // Enable forwarding. + s.SetForwarding(proto.Number(), true) + + // NIC 1 has the link address "a", and added the network address 1. + ep1 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "a", + } + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC #1 failed:", err) + } + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatal("AddAddress #1 failed:", err) + } + + // NIC 2 has the link address "b", and added the network address 2. + ep2 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "b", + } + if err := s.CreateNIC(2, ep2); err != nil { + t.Fatal("CreateNIC #2 failed:", err) + } + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatal("AddAddress #2 failed:", err) + } + + // Route all packets to NIC 2. + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) + } + + return ep1, ep2 +} + +func TestForwardingWithStaticResolver(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolver(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithNoResolver(t *testing.T) { + // Create a network protocol without a resolver. + proto := &fwdTestNetworkProtocol{} + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + select { + case <-ep2.C: + t.Fatal("Packet should not be forwarded") + case <-time.After(time.Second): + } +} + +func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go new file mode 100644 index 000000000..5efddfaaf --- /dev/null +++ b/pkg/tcpip/stack/headertype_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type headerType ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[linkHeader-0] + _ = x[networkHeader-1] + _ = x[transportHeader-2] + _ = x[numHeaderType-3] +} + +const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType" + +var _headerType_index = [...]uint8{0, 10, 23, 38, 51} + +func (i headerType) String() string { + if i < 0 || i >= headerType(len(_headerType_index)-1) { + return "headerType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _headerType_name[_headerType_index[i]:_headerType_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go new file mode 100644 index 000000000..c37da814f --- /dev/null +++ b/pkg/tcpip/stack/iptables.go @@ -0,0 +1,423 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// tableID is an index into IPTables.tables. +type tableID int + +const ( + natID tableID = iota + mangleID + filterID + numTables +) + +// Table names. +const ( + NATTable = "nat" + MangleTable = "mangle" + FilterTable = "filter" +) + +// nameToID is immutable. +var nameToID = map[string]tableID{ + NATTable: natID, + MangleTable: mangleID, + FilterTable: filterID, +} + +// HookUnset indicates that there is no hook set for an entrypoint or +// underflow. +const HookUnset = -1 + +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + +// DefaultTables returns a default set of tables. Each chain is set to accept +// all packets. +func DefaultTables() *IPTables { + return &IPTables{ + tables: [numTables]Table{ + natID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + }, + mangleID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, + }, + filterID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + }, + }, + priorities: [NumHooks][]tableID{ + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, + }, + connections: ConnTrack{ + seed: generateRandUint32(), + }, + reaperDone: make(chan struct{}, 1), + } +} + +// EmptyFilterTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyFilterTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, + }, + } +} + +// EmptyNATTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyNATTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: [NumHooks]int{ + Forward: HookUnset, + }, + Underflows: [NumHooks]int{ + Forward: HookUnset, + }, + } +} + +// GetTable returns a table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + id, ok := nameToID[name] + if !ok { + return Table{}, false + } + it.mu.RLock() + defer it.mu.RUnlock() + return it.tables[id], true +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { + id, ok := nameToID[name] + if !ok { + return tcpip.ErrInvalidOptionValue + } + it.mu.Lock() + defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } + it.modified = true + it.tables[id] = table + return nil +} + +// A chainVerdict is what a table decides should be done with a packet. +type chainVerdict int + +const ( + // chainAccept indicates the packet should continue through netstack. + chainAccept chainVerdict = iota + + // chainAccept indicates the packet should be dropped. + chainDrop + + // chainReturn indicates the packet should return to the calling chain + // or the underflow rule of a builtin chain. + chainReturn +) + +// Check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { + // Many users never configure iptables. Spare them the cost of rule + // traversal if rules have never been set. + it.mu.RLock() + defer it.mu.RUnlock() + if !it.modified { + return true + } + + // Packets are manipulated only if connection and matching + // NAT rule exists. + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) + + // Go through each table containing the hook. + priorities := it.priorities[hook] + for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } + table := it.tables[tableID] + ruleIdx := table.BuiltinChains[hook] + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + // If the table returns Accept, move on to the next table. + case chainAccept: + continue + // The Drop verdict is final. + case chainDrop: + return false + case chainReturn: + // Any Return from a built-in chain means we have to + // call the underflow. + underflow := table.Rules[table.Underflows[hook]] + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { + case RuleAccept: + continue + case RuleDrop: + return false + case RuleJump, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + + default: + panic(fmt.Sprintf("Unknown verdict %v.", verdict)) + } + } + + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + + // Every table returned Accept. + return true +} + +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + +// CheckPackets runs pkts through the rules for hook and returns a map of packets that +// should not go forward. +// +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. +// +// NOTE: unlike the Check API the returned map contains packets that should be +// dropped. +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if !pkt.NatDone { + if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + if pkt.NatDone { + if natPkts == nil { + natPkts = make(map[*PacketBuffer]struct{}) + } + natPkts[pkt] = struct{}{} + } + } + } + return drop, natPkts +} + +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { + // Start from ruleIdx and walk the list of rules until a rule gives us + // a verdict. + for ruleIdx < len(table.Rules) { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + case RuleAccept: + return chainAccept + + case RuleDrop: + return chainDrop + + case RuleReturn: + return chainReturn + + case RuleJump: + // "Jumping" to the next rule just means we're + // continuing on down the list. + if jumpTo == ruleIdx+1 { + ruleIdx++ + continue + } + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { + case chainAccept: + return chainAccept + case chainDrop: + return chainDrop + case chainReturn: + ruleIdx++ + continue + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + } + + // We got through the entire table without a decision. Default to DROP + // for safety. + return chainDrop +} + +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { + rule := table.Rules[ruleIdx] + + // Check whether the packet matches the IP header filter. + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + + // Go through each rule matcher. If they all match, run + // the rule target. + for _, matcher := range rule.Matchers { + matches, hotdrop := matcher.Match(hook, pkt, "") + if hotdrop { + return RuleDrop, 0 + } + if !matches { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + } + + // All the matchers matched, so run the target. + return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) +} + +// OriginalDst returns the original destination of redirected connections. It +// returns an error if the connection doesn't exist or isn't redirected. +func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + return it.connections.originalDst(epID) +} diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go new file mode 100644 index 000000000..5f1b2af64 --- /dev/null +++ b/pkg/tcpip/stack/iptables_targets.go @@ -0,0 +1,163 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// AcceptTarget accepts packets. +type AcceptTarget struct{} + +// Action implements Target.Action. +func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { + return RuleAccept, 0 +} + +// DropTarget drops packets. +type DropTarget struct{} + +// Action implements Target.Action. +func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { + return RuleDrop, 0 +} + +// ErrorTarget logs an error and drops the packet. It represents a target that +// should be unreachable. +type ErrorTarget struct{} + +// Action implements Target.Action. +func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { + log.Debugf("ErrorTarget triggered.") + return RuleDrop, 0 +} + +// UserChainTarget marks a rule as the beginning of a user chain. +type UserChainTarget struct { + Name string +} + +// Action implements Target.Action. +func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { + panic("UserChainTarget should never be called.") +} + +// ReturnTarget returns from the current chain. If the chain is a built-in, the +// hook's underflow should be called. +type ReturnTarget struct{} + +// Action implements Target.Action. +func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { + return RuleReturn, 0 +} + +// RedirectTarget redirects the packet by modifying the destination port/IP. +// Min and Max values for IP and Ports in the struct indicate the range of +// values which can be used to redirect. +type RedirectTarget struct { + // TODO(gvisor.dev/issue/170): Other flags need to be added after + // we support them. + // RangeProtoSpecified flag indicates single port is specified to + // redirect. + RangeProtoSpecified bool + + // MinIP indicates address used to redirect. + MinIP tcpip.Address + + // MaxIP indicates address used to redirect. + MaxIP tcpip.Address + + // MinPort indicates port used to redirect. + MinPort uint16 + + // MaxPort indicates port used to redirect. + MaxPort uint16 +} + +// Action implements Target.Action. +// TODO(gvisor.dev/issue/170): Parse headers without copying. The current +// implementation only works for PREROUTING and calls pkt.Clone(), neither +// of which should be the case. +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } + + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { + return RuleDrop, 0 + } + + // Change the address to localhost (127.0.0.1) in Output and + // to primary address of the incoming interface in Prerouting. + switch hook { + case Output: + rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) + rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + case Prerouting: + rt.MinIP = address + rt.MaxIP = address + default: + panic("redirect target is supported only on output and prerouting hooks") + } + + // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if + // we need to change dest address (for OUTPUT chain) or ports. + netHeader := header.IPv4(pkt.NetworkHeader().View()) + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + udpHeader := header.UDP(pkt.TransportHeader().View()) + udpHeader.SetDestinationPort(rt.MinPort) + + // Calculate UDP checksum and set it. + if hook == Output { + udpHeader.SetChecksum(0) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum := r.PseudoHeaderChecksum(protocol, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHeader.SetChecksum(0) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) + } + } + // Change destination address. + netHeader.SetDestinationAddress(rt.MinIP) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + pkt.NatDone = true + case header.TCPProtocolNumber: + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + ct.handlePacket(pkt, hook, gso, r) + } + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go new file mode 100644 index 000000000..73274ada9 --- /dev/null +++ b/pkg/tcpip/stack/iptables_types.go @@ -0,0 +1,262 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "strings" + "sync" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// A Hook specifies one of the hooks built into the network stack. +// +// Userspace app Userspace app +// ^ | +// | v +// [Input] [Output] +// ^ | +// | v +// | routing +// | | +// | v +// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> +type Hook uint + +// These values correspond to values in include/uapi/linux/netfilter.h. +const ( + // Prerouting happens before a packet is routed to applications or to + // be forwarded. + Prerouting Hook = iota + + // Input happens before a packet reaches an application. + Input + + // Forward happens once it's decided that a packet should be forwarded + // to another host. + Forward + + // Output happens after a packet is written by an application to be + // sent out. + Output + + // Postrouting happens just before a packet goes out on the wire. + Postrouting + + // The total number of hooks. + NumHooks +) + +// A RuleVerdict is what a rule decides should be done with a packet. +type RuleVerdict int + +const ( + // RuleAccept indicates the packet should continue through netstack. + RuleAccept RuleVerdict = iota + + // RuleDrop indicates the packet should be dropped. + RuleDrop + + // RuleJump indicates the packet should jump to another chain. + RuleJump + + // RuleReturn indicates the packet should return to the previous chain. + RuleReturn +) + +// IPTables holds all the tables for a netstack. +// +// +stateify savable +type IPTables struct { + // mu protects tables, priorities, and modified. + mu sync.RWMutex + + // tables maps tableIDs to tables. Holds builtin tables only, not user + // tables. mu must be locked for accessing. + tables [numTables]Table + + // priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. mu needs to be locked for accessing. + priorities [NumHooks][]tableID + + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + modified bool + + connections ConnTrack + + // reaperDone can be signalled to stop the reaper goroutine. + reaperDone chan struct{} +} + +// A Table defines a set of chains and hooks into the network stack. It is +// really just a list of rules. +// +// +stateify savable +type Table struct { + // Rules holds the rules that make up the table. + Rules []Rule + + // BuiltinChains maps builtin chains to their entrypoint rule in Rules. + BuiltinChains [NumHooks]int + + // Underflows maps builtin chains to their underflow rule in Rules + // (i.e. the rule to execute if the chain returns without a verdict). + Underflows [NumHooks]int +} + +// ValidHooks returns a bitmap of the builtin hooks for the given table. +func (table *Table) ValidHooks() uint32 { + hooks := uint32(0) + for hook, ruleIdx := range table.BuiltinChains { + if ruleIdx != HookUnset { + hooks |= 1 << hook + } + } + return hooks +} + +// A Rule is a packet processing rule. It consists of two pieces. First it +// contains zero or more matchers, each of which is a specification of which +// packets this rule applies to. If there are no matchers in the rule, it +// applies to any packet. +// +// +stateify savable +type Rule struct { + // Filter holds basic IP filtering fields common to every rule. + Filter IPHeaderFilter + + // Matchers is the list of matchers for this rule. + Matchers []Matcher + + // Target is the action to invoke if all the matchers match the packet. + Target Target +} + +// IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable +type IPHeaderFilter struct { + // Protocol matches the transport protocol. + Protocol tcpip.TransportProtocolNumber + + // Dst matches the destination IP address. + Dst tcpip.Address + + // DstMask masks bits of the destination IP address when comparing with + // Dst. + DstMask tcpip.Address + + // DstInvert inverts the meaning of the destination IP check, i.e. when + // true the filter will match packets that fail the destination + // comparison. + DstInvert bool + + // Src matches the source IP address. + Src tcpip.Address + + // SrcMask masks bits of the source IP address when comparing with Src. + SrcMask tcpip.Address + + // SrcInvert inverts the meaning of the source IP check, i.e. when true the + // filter will match packets that fail the source comparison. + SrcInvert bool + + // OutputInterface matches the name of the outgoing interface for the + // packet. + OutputInterface string + + // OutputInterfaceMask masks the characters of the interface name when + // comparing with OutputInterface. + OutputInterfaceMask string + + // OutputInterfaceInvert inverts the meaning of outgoing interface check, + // i.e. when true the filter will match packets that fail the outgoing + // interface comparison. + OutputInterfaceInvert bool +} + +// match returns whether hdr matches the filter. +func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the source and destination IPs. + if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + return false + } + + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(fl.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := fl.OutputInterface + matches := true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return fl.OutputInterfaceInvert != matches + } + + return true +} + +// filterAddress returns whether addr matches the filter. +func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { + matches := true + for i := range filterAddr { + if addr[i]&mask[i] != filterAddr[i] { + matches = false + break + } + } + return matches != invert +} + +// A Matcher is the interface for matching packets. +type Matcher interface { + // Name returns the name of the Matcher. + Name() string + + // Match returns whether the packet matches and whether the packet + // should be "hotdropped", i.e. dropped immediately. This is usually + // used for suspicious packets. + // + // Precondition: packet.NetworkHeader is set. + Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) +} + +// A Target is the interface for taking an action for a packet. +type Target interface { + // Action takes an action on the packet and returns a verdict on how + // traversal should (or should not) continue. If the return value is + // Jump, it also returns the index of the rule to jump to. + Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) +} diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 267df60d1..6f73a0ce4 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -16,10 +16,10 @@ package stack import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 9946b8fe8..b15b8d1cb 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,12 +16,12 @@ package stack import ( "fmt" - "sync" "sync/atomic" "testing" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) if f := r.onLinkAddressRequest; f != nil { f() diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 03ddebdbd..97ca00d16 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "log" + "math/rand" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -32,39 +33,277 @@ const ( // Default = 1 (from RFC 4862 section 5.1) defaultDupAddrDetectTransmits = 1 - // defaultRetransmitTimer is the default amount of time to wait between - // sending NDP Neighbor solicitation messages. + // defaultMaxRtrSolicitations is the default number of Router + // Solicitation messages to send when a NIC becomes enabled. // - // Default = 1s (from RFC 4861 section 10). - defaultRetransmitTimer = time.Second + // Default = 3 (from RFC 4861 section 10). + defaultMaxRtrSolicitations = 3 - // minimumRetransmitTimer is the minimum amount of time to wait between - // sending NDP Neighbor solicitation messages. Note, RFC 4861 does - // not impose a minimum Retransmit Timer, but we do here to make sure - // the messages are not sent all at once. We also come to this value - // because in the RetransmitTimer field of a Router Advertisement, a - // value of 0 means unspecified, so the smallest valid value is 1. - // Note, the unit of the RetransmitTimer field in the Router - // Advertisement is milliseconds. + // defaultRtrSolicitationInterval is the default amount of time between + // sending Router Solicitation messages. // - // Min = 1ms. - minimumRetransmitTimer = time.Millisecond + // Default = 4s (from 4861 section 10). + defaultRtrSolicitationInterval = 4 * time.Second + + // defaultMaxRtrSolicitationDelay is the default maximum amount of time + // to wait before sending the first Router Solicitation message. + // + // Default = 1s (from 4861 section 10). + defaultMaxRtrSolicitationDelay = time.Second + + // defaultHandleRAs is the default configuration for whether or not to + // handle incoming Router Advertisements as a host. + defaultHandleRAs = true + + // defaultDiscoverDefaultRouters is the default configuration for + // whether or not to discover default routers from incoming Router + // Advertisements, as a host. + defaultDiscoverDefaultRouters = true + + // defaultDiscoverOnLinkPrefixes is the default configuration for + // whether or not to discover on-link prefixes from incoming Router + // Advertisements' Prefix Information option, as a host. + defaultDiscoverOnLinkPrefixes = true + + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + + // minimumRtrSolicitationInterval is the minimum amount of time to wait + // between sending Router Solicitation messages. This limit is imposed + // to make sure that Router Solicitation messages are not sent all at + // once, defeating the purpose of sending the initial few messages. + minimumRtrSolicitationInterval = 500 * time.Millisecond + + // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait + // before sending the first Router Solicitation message. It is 0 because + // we cannot have a negative delay. + minimumMaxRtrSolicitationDelay = 0 + + // MaxDiscoveredDefaultRouters is the maximum number of discovered + // default routers. The stack should stop discovering new routers after + // discovering MaxDiscoveredDefaultRouters routers. + // + // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and + // SHOULD be more. + MaxDiscoveredDefaultRouters = 10 + + // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered + // on-link prefixes. The stack should stop discovering new on-link + // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link + // prefixes. + MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 + + // defaultAutoGenTempGlobalAddresses is the default configuration for whether + // or not to generate temporary SLAAC addresses. + defaultAutoGenTempGlobalAddresses = true + + // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 7 days (from RFC 4941 section 5). + defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour + + // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 1 day (from RFC 4941 section 5). + defaultMaxTempAddrPreferredLifetime = 24 * time.Hour + + // defaultRegenAdvanceDuration is the default duration before the deprecation + // of a temporary address when a new address will be generated. + // + // Default = 5s (from RFC 4941 section 5). + defaultRegenAdvanceDuration = 5 * time.Second + + // minRegenAdvanceDuration is the minimum duration before the deprecation + // of a temporary address when a new address will be generated. + minRegenAdvanceDuration = time.Duration(0) + + // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt + // SLAAC address regenerations in response to a NIC-local conflict. + maxSLAACAddrLocalRegenAttempts = 10 +) + +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour + + // MaxDesyncFactor is the upper bound for the preferred lifetime's desync + // factor for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Must be greater than 0. + // + // Max = 10m (from RFC 4941 section 5). + MaxDesyncFactor = 10 * time.Minute + + // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the + // maximum preferred lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be preferred for at + // least 1hr if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour + + // MinMaxTempAddrValidLifetime is the minimum value allowed for the + // maximum valid lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be valid for at least + // 2hrs if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrValidLifetime = 2 * time.Hour +) + +// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an +// NDP Router Advertisement informed the Stack about. +type DHCPv6ConfigurationFromNDPRA int + +const ( + _ DHCPv6ConfigurationFromNDPRA = iota + + // DHCPv6NoConfiguration indicates that no configurations are available via + // DHCPv6. + DHCPv6NoConfiguration + + // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. + // + // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 + // will return all available configuration information. + DHCPv6ManagedAddress + + // DHCPv6OtherConfigurations indicates that other configuration information is + // available via DHCPv6. + // + // Other configurations are configurations other than addresses. Examples of + // other configurations are recursive DNS server list, DNS search lists and + // default gateway. + DHCPv6OtherConfigurations ) // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { // OnDuplicateAddressDetectionStatus will be called when the DAD process - // for an address (addr) on a NIC (with ID nicid) completes. resolved + // for an address (addr) on a NIC (with ID nicID) completes. resolved // will be set to true if DAD completed successfully (no duplicate addr // detected); false otherwise (addr was detected to be a duplicate on // the link the NIC is a part of, or it was stopped for some other // reason, such as the address being removed). If an error occured // during DAD, err will be set and resolved must be ignored. // - // This function is permitted to block indefinitely without interfering - // with the stack's operation. - OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + + // OnDefaultRouterDiscovered will be called when a new default router is + // discovered. Implementations must return true if the newly discovered + // router should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool + + // OnDefaultRouterInvalidated will be called when a discovered default + // router that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) + + // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is + // discovered. Implementations must return true if the newly discovered + // on-link prefix should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool + + // OnOnLinkPrefixInvalidated will be called when a discovered on-link + // prefix that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + + // OnAutoGenAddress will be called when a new prefix with its + // autonomous address-configuration flag set has been received and SLAAC + // has been performed. Implementations may prevent the stack from + // assigning the address to the NIC by returning false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressDeprecated will be called when an auto-generated + // address (as part of SLAAC) has been deprecated, but is still + // considered valid. Note, if an address is invalidated at the same + // time it is deprecated, the deprecation event MAY be omitted. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnAutoGenAddressInvalidated will be called when an auto-generated + // address (as part of SLAAC) has been invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption will be called when an NDP option with + // recursive DNS servers has been received. Note, addrs may contain + // link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + + // OnDNSSearchListOption will be called when an NDP option with a DNS + // search list has been received. + // + // It is up to the caller to use the domain names in the search list + // for only their valid lifetime. OnDNSSearchListOption may be called + // with new or already known domain names. If called with known domain + // names, their valid lifetimes must be refreshed to lifetime (it may + // be increased, decreased or completely invalidated when lifetime = 0. + OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) + + // OnDHCPv6Configuration will be called with an updated configuration that is + // available via DHCPv6 for a specified NIC. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } // NDPConfigurations is the NDP configurations for the netstack. @@ -78,28 +317,124 @@ type NDPConfigurations struct { // The amount of time to wait between sending Neighbor solicitation // messages. // - // Must be greater than 0.5s. + // Must be greater than or equal to 1ms. RetransmitTimer time.Duration + + // The number of Router Solicitation messages to send when the NIC + // becomes enabled. + MaxRtrSolicitations uint8 + + // The amount of time between transmitting Router Solicitation messages. + // + // Must be greater than or equal to 0.5s. + RtrSolicitationInterval time.Duration + + // The maximum amount of time before transmitting the first Router + // Solicitation message. + // + // Must be greater than or equal to 0s. + MaxRtrSolicitationDelay time.Duration + + // HandleRAs determines whether or not Router Advertisements will be + // processed. + HandleRAs bool + + // DiscoverDefaultRouters determines whether or not default routers will + // be discovered from Router Advertisements. This configuration is + // ignored if HandleRAs is false. + DiscoverDefaultRouters bool + + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes + // will be discovered from Router Advertisements' Prefix Information + // option. This configuration is ignored if HandleRAs is false. + DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not global IPv6 + // addresses will be generated for a NIC in response to receiving a new + // Prefix Information option with its Autonomous Address + // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool + + // AutoGenAddressConflictRetries determines how many times to attempt to retry + // generation of a permanent auto-generated address in response to DAD + // conflicts. + // + // If the method used to generate the address does not support creating + // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's + // MAC address), then no attempt will be made to resolve the conflict. + AutoGenAddressConflictRetries uint8 + + // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC + // addresses will be generated for a NIC as part of SLAAC privacy extensions, + // RFC 4941. + // + // Ignored if AutoGenGlobalAddresses is false. + AutoGenTempGlobalAddresses bool + + // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary + // SLAAC addresses. + MaxTempAddrValidLifetime time.Duration + + // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for + // temporary SLAAC addresses. + MaxTempAddrPreferredLifetime time.Duration + + // RegenAdvanceDuration is the duration before the deprecation of a temporary + // address when a new address will be generated. + RegenAdvanceDuration time.Duration } // DefaultNDPConfigurations returns an NDPConfigurations populated with // default values. func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + MaxRtrSolicitations: defaultMaxRtrSolicitations, + RtrSolicitationInterval: defaultRtrSolicitationInterval, + MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses, + MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime, + RegenAdvanceDuration: defaultRegenAdvanceDuration, } } // validate modifies an NDPConfigurations with valid values. If invalid values // are present in c, the corresponding default values will be used instead. -// -// If RetransmitTimer is less than minimumRetransmitTimer, then a value of -// defaultRetransmitTimer will be used. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer } + + if c.RtrSolicitationInterval < minimumRtrSolicitationInterval { + c.RtrSolicitationInterval = defaultRtrSolicitationInterval + } + + if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { + c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay + } + + if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime { + c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime + } + + if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime { + c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime + } + + if c.RegenAdvanceDuration < minRegenAdvanceDuration { + c.RegenAdvanceDuration = minRegenAdvanceDuration + } } // ndpState is the per-interface NDP state. @@ -112,13 +447,47 @@ type ndpState struct { // The DAD state to send the next NS message, or resolve the address. dad map[tcpip.Address]dadState + + // The default routers discovered through Router Advertisements. + defaultRouters map[tcpip.Address]defaultRouterState + + rtrSolicit struct { + // The timer used to send the next router solicitation message. + timer tcpip.Timer + + // Used to let the Router Solicitation timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this ndpState is associated with. MUST be set when the timer is + // set. + done *bool + } + + // The on-link prefixes discovered through Router Advertisements' Prefix + // Information option. + onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The SLAAC prefixes discovered through Router Advertisements' Prefix + // Information option. + slaacPrefixes map[tcpip.Subnet]slaacPrefixState + + // The last learned DHCPv6 configuration from an NDP RA. + dhcpv6Configuration DHCPv6ConfigurationFromNDPRA + + // temporaryIIDHistory is the history value used to generate a new temporary + // IID. + temporaryIIDHistory [header.IIDSize]byte + + // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for + // temporary SLAAC addresses. + temporaryAddressDesyncFactor time.Duration } // dadState holds the Duplicate Address Detection timer and channel to signal // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer *time.Timer + timer tcpip.Timer // Used to let the DAD timer know that it has been stopped. // @@ -127,6 +496,102 @@ type dadState struct { done *bool } +// defaultRouterState holds data associated with a default router discovered by +// a Router Advertisement (RA). +type defaultRouterState struct { + // Job to invalidate the default router. + // + // Must not be nil. + invalidationJob *tcpip.Job +} + +// onLinkPrefixState holds data associated with an on-link prefix discovered by +// a Router Advertisement's Prefix Information option (PI) when the NDP +// configurations was configured to do so. +type onLinkPrefixState struct { + // Job to invalidate the on-link prefix. + // + // Must not be nil. + invalidationJob *tcpip.Job +} + +// tempSLAACAddrState holds state associated with a temporary SLAAC address. +type tempSLAACAddrState struct { + // Job to deprecate the temporary SLAAC address. + // + // Must not be nil. + deprecationJob *tcpip.Job + + // Job to invalidate the temporary SLAAC address. + // + // Must not be nil. + invalidationJob *tcpip.Job + + // Job to regenerate the temporary SLAAC address. + // + // Must not be nil. + regenJob *tcpip.Job + + createdAt time.Time + + // The address's endpoint. + // + // Must not be nil. + ref *referencedNetworkEndpoint + + // Has a new temporary SLAAC address already been regenerated? + regenerated bool +} + +// slaacPrefixState holds state associated with a SLAAC prefix. +type slaacPrefixState struct { + // Job to deprecate the prefix. + // + // Must not be nil. + deprecationJob *tcpip.Job + + // Job to invalidate the prefix. + // + // Must not be nil. + invalidationJob *tcpip.Job + + // Nonzero only when the address is not valid forever. + validUntil time.Time + + // Nonzero only when the address is not preferred forever. + preferredUntil time.Time + + // State associated with the stable address generated for the prefix. + stableAddr struct { + // The address's endpoint. + // + // May only be nil when the address is being (re-)generated. Otherwise, + // must not be nil as all SLAAC prefixes must have a stable address. + ref *referencedNetworkEndpoint + + // The number of times an address has been generated locally where the NIC + // already had the generated address. + localGenerationFailures uint8 + } + + // The temporary (short-lived) addresses generated for the SLAAC prefix. + tempAddrs map[tcpip.Address]tempSLAACAddrState + + // The next two fields are used by both stable and temporary addresses + // generated for a SLAAC prefix. This is safe as only 1 address will be + // in the generation and DAD process at any time. That is, no two addresses + // will be generated at the same time for a given SLAAC prefix. + + // The number of times an address has been generated and added to the NIC. + // + // Addresses may be regenerated in reseponse to a DAD conflicts. + generationAttempts uint8 + + // The maximum number of times to attempt regeneration of a SLAAC address + // in response to DAD conflicts. + maxGenerationAttempts uint8 +} + // startDuplicateAddressDetection performs Duplicate Address Detection. // // This function must only be called by IPv6 addresses that are currently @@ -139,87 +604,110 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return tcpip.ErrAddressFamilyNotSupported } - // Should not attempt to perform DAD on an address that is currently in - // the DAD process. + if ref.getKind() != permanentTentative { + // The endpoint should be marked as tentative since we are starting DAD. + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) + } + + // Should not attempt to perform DAD on an address that is currently in the + // DAD process. if _, ok := ndp.dad[addr]; ok { - // Should never happen because we should only ever call this - // function for newly created addresses. If we attemped to - // "add" an address that already existed, we would returned an - // error since we attempted to add a duplicate address, or its - // reference count would have been increased without doing the - // work that would have been done for an address that was brand - // new. See NIC.addPermanentAddressLocked. + // Should never happen because we should only ever call this function for + // newly created addresses. If we attemped to "add" an address that already + // existed, we would get an error since we attempted to add a duplicate + // address, or its reference count would have been increased without doing + // the work that would have been done for an address that was brand new. + // See NIC.addAddressLocked. panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) } remaining := ndp.configs.DupAddrDetectTransmits + if remaining == 0 { + ref.setKind(permanent) - { - done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil { - return err - } - if done { - return nil + // Consider DAD to have resolved even if no DAD messages were actually + // transmitted. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil) } - } - remaining-- + return nil + } var done bool - var timer *time.Timer - timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { - var d bool - var err *tcpip.Error + var timer tcpip.Timer + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the NIC's lock. This is effectively the same + // as starting a goroutine but we use a timer that fires immediately so we can + // reset it for the next DAD iteration. + timer = ndp.nic.stack.Clock().AfterFunc(0, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() - // doDadIteration does a single iteration of the DAD loop. - // - // Returns true if the integrator needs to be informed of DAD - // completing. - doDadIteration := func() bool { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD - // timer fired after another goroutine already - // obtained the NIC lock and stopped DAD before - // this function obtained the NIC lock. Simply - // return here and do nothing further. - return false - } + if done { + // If we reach this point, it means that the DAD timer fired after + // another goroutine already obtained the NIC lock and stopped DAD + // before this function obtained the NIC lock. Simply return here and do + // nothing further. + return + } - ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] - if !ok { - // This should never happen. - // We should have an endpoint for addr since we - // are still performing DAD on it. If the - // endpoint does not exist, but we are doing DAD - // on it, then we started DAD at some point, but - // forgot to stop it when the endpoint was - // deleted. - panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) - } + if ref.getKind() != permanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) + } - d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil || d { - delete(ndp.dad, addr) + dadDone := remaining == 0 - if err != nil { - log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) - } + var err *tcpip.Error + if !dadDone { + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) - // Let the integrator know DAD has completed. - return true - } + // Do not hold the lock when sending packets which may be a long running + // task or may block link address resolution. We know this is safe + // because immediately after obtaining the lock again, we check if DAD + // has been stopped before doing any work with the NIC. Note, DAD would be + // stopped if the NIC was disabled or removed, or if the address was + // removed. + ndp.nic.mu.Unlock() + err = ndp.sendDADPacket(addr, ref) + ndp.nic.mu.Lock() + } + if done { + // If we reach this point, it means that DAD was stopped after we released + // the NIC's read lock and before we obtained the write lock. + return + } + + if dadDone { + // DAD has resolved. + ref.setKind(permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. remaining-- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - return false + return + } + + // At this point we know that either DAD is done or we hit an error sending + // the last NDP NS. Either way, clean up addr's DAD state and let the + // integrator know DAD has completed. + delete(ndp.dad, addr) + + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) } - if doDadIteration() && ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + // If DAD resolved for a stable SLAAC address, attempt generation of a + // temporary SLAAC address. + if dadDone && ref.configType == slaac { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) } }) @@ -231,60 +719,58 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return nil } -// doDuplicateAddressDetection is called on every iteration of the timer, and -// when DAD starts. -// -// It handles resolving the address (if there are no more NS to send), or -// sending the next NS if there are more NS to send. -// -// This function must only be called by IPv6 addresses that are currently -// tentative. +// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns +// addr. // -// The NIC that ndp belongs to (n) MUST be locked. +// addr must be a tentative IPv6 address on ndp's NIC. // -// Returns true if DAD has resolved; false if DAD is still ongoing. -func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { - if ref.getKind() != permanentTentative { - // The endpoint should still be marked as tentative - // since we are still performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) - } +// The NIC ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { + snmc := header.SolicitedNodeAddr(addr) - if remaining == 0 { - // DAD has resolved. - ref.setKind(permanent) - return true, nil - } + r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + defer r.Release() - // Send a new NS. - snmc := header.SolicitedNodeAddr(addr) - snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] - if !ok { - // This should never happen as if we have the - // address, we should have the solicited-node - // address. - panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) - } + // Route should resolve immediately since snmc is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { + return err + } - // Use the unspecified address as the source address when performing - // DAD. - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) + } else if c != nil { + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) + } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmpData.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) ns.SetTargetAddress(addr) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil { + if err := r.WritePacket(nil, + NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + TOS: DefaultTOS, + }, pkt, + ); err != nil { sent.Dropped.Increment() - return false, err + return err } sent.NeighborSolicit.Increment() - return false, nil + return nil } // stopDuplicateAddressDetection ends a running Duplicate Address Detection @@ -315,7 +801,1173 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndp.nic.stack.ndpDisp != nil { - go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + } +} + +// handleRA handles a Router Advertisement message that arrived on the NIC +// this ndp is for. Does nothing if the NIC is configured to not handle RAs. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + // Is the NIC configured to handle RAs at all? + // + // Currently, the stack does not determine router interface status on a + // per-interface basis; it is a stack-wide configuration, so we check + // stack's forwarding flag to determine if the NIC is a routing + // interface. + if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) { + return + } + + // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we + // only inform the dispatcher on configuration changes. We do nothing else + // with the information. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + var configuration DHCPv6ConfigurationFromNDPRA + switch { + case ra.ManagedAddrConfFlag(): + configuration = DHCPv6ManagedAddress + + case ra.OtherConfFlag(): + configuration = DHCPv6OtherConfigurations + + default: + configuration = DHCPv6NoConfiguration + } + + if ndp.dhcpv6Configuration != configuration { + ndp.dhcpv6Configuration = configuration + ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + } + } + + // Is the NIC configured to discover default routers? + if ndp.configs.DiscoverDefaultRouters { + rtr, ok := ndp.defaultRouters[ip] + rl := ra.RouterLifetime() + switch { + case !ok && rl != 0: + // This is a new default router we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredDefaultRouters routers. + if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { + ndp.rememberDefaultRouter(ip, rl) + } + + case ok && rl != 0: + // This is an already discovered default router. Update + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) + ndp.defaultRouters[ip] = rtr + + case ok && rl == 0: + // We know about the router but it is no longer to be + // used as a default router so invalidate it. + ndp.invalidateDefaultRouter(ip) + } + } + + // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter + // Discovery. + + // We know the options is valid as far as wire format is concerned since + // we got the Router Advertisement, as documented by this fn. Given this + // we do not check the iterator for errors on calls to Next. + it, _ := ra.Options().Iter(false) + for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + addrs, _ := opt.Addresses() + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) + + case header.NDPDNSSearchList: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + domainNames, _ := opt.DomainNames() + ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) + + case header.NDPPrefixInformation: + prefix := opt.Subnet() + + // Is the prefix a link-local? + if header.IsV6LinkLocalAddress(prefix.ID()) { + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). + continue + } + + // Is the Prefix Length 0? + if prefix.Prefix() == 0 { + // ...Yes, skip as this is an invalid prefix + // as all IPv6 addresses cannot be on-link. + continue + } + + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) + } + + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) + } + } + + // TODO(b/141556115): Do (MTU) Parameter Discovery. + } +} + +// invalidateDefaultRouter invalidates a discovered default router. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { + rtr, ok := ndp.defaultRouters[ip] + + // Is the router still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + rtr.invalidationJob.Cancel() + delete(ndp.defaultRouters, ip) + + // Let the integrator know a discovered default router is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + } +} + +// rememberDefaultRouter remembers a newly discovered default router with IPv6 +// link-local address ip with lifetime rl. +// +// The router identified by ip MUST NOT already be known by the NIC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered a default router. + if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { + // Informed by the integrator to not remember the router, do + // nothing further. + return + } + + state := defaultRouterState{ + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + ndp.invalidateDefaultRouter(ip) + }), + } + + state.invalidationJob.Schedule(rl) + + ndp.defaultRouters[ip] = state +} + +// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 +// address with prefix prefix with lifetime l. +// +// The prefix identified by prefix MUST NOT already be known. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered an on-link prefix. + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { + // Informed by the integrator to not remember the prefix, do + // nothing further. + return + } + + state := onLinkPrefixState{ + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + ndp.invalidateOnLinkPrefix(prefix) + }), + } + + if l < header.NDPInfiniteLifetime { + state.invalidationJob.Schedule(l) + } + + ndp.onLinkPrefixes[prefix] = state +} + +// invalidateOnLinkPrefix invalidates a discovered on-link prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { + s, ok := ndp.onLinkPrefixes[prefix] + + // Is the on-link prefix still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + s.invalidationJob.Cancel() + delete(ndp.onLinkPrefixes, prefix) + + // Let the integrator know a discovered on-link prefix is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + } +} + +// handleOnLinkPrefixInformation handles a Prefix Information option with +// its on-link flag set, as per RFC 4861 section 6.3.4. +// +// handleOnLinkPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the on-link flag is set. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { + prefix := pi.Subnet() + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + + if !ok && vl == 0 { + // Don't know about this prefix but it has a zero valid + // lifetime, so just ignore. + return + } + + if !ok && vl != 0 { + // This is a new on-link prefix we are discovering + // + // Only remember it if we currently know about less than + // MaxDiscoveredOnLinkPrefixes on-link prefixes. + if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + return + } + + if ok && vl == 0 { + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + return + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // + // Update the invalidation job. + + prefixState.invalidationJob.Cancel() + + if vl < header.NDPInfiniteLifetime { + // Prefix is valid for a finite lifetime, schedule the job to execute after + // the new valid lifetime. + prefixState.invalidationJob.Schedule(vl) + } + + ndp.onLinkPrefixes[prefix] = prefixState +} + +// handleAutonomousPrefixInformation handles a Prefix Information option with +// its autonomous flag set, as per RFC 4862 section 5.5.3. +// +// handleAutonomousPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the autonomous flag is set. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { + vl := pi.ValidLifetime() + pl := pi.PreferredLifetime() + + // If the preferred lifetime is greater than the valid lifetime, + // silently ignore the Prefix Information option, as per RFC 4862 + // section 5.5.3.c. + if pl > vl { + return + } + + prefix := pi.Subnet() + + // Check if we already maintain SLAAC state for prefix. + if state, ok := ndp.slaacPrefixes[prefix]; ok { + // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. + ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl) + ndp.slaacPrefixes[prefix] = state + return + } + + // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC. + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + ndp.doSLAAC(prefix, pl, vl) +} + +// doSLAAC generates a new SLAAC address with the provided lifetimes +// for prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { + // If we do not already have an address for this prefix and the valid + // lifetime is 0, no need to do anything further, as per RFC 4862 + // section 5.5.3.d. + if vl == 0 { + return + } + + // Make sure the prefix is valid (as far as its length is concerned) to + // generate a valid IPv6 address from an interface identifier (IID), as + // per RFC 4862 sectiion 5.5.3.d. + if prefix.Prefix() != validPrefixLenForAutoGen { + return + } + + state := slaacPrefixState{ + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) + } + + ndp.deprecateSLAACAddress(state.stableAddr.ref) + }), + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) + } + + ndp.invalidateSLAACPrefix(prefix, state) + }), + tempAddrs: make(map[tcpip.Address]tempSLAACAddrState), + maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, + } + + now := time.Now() + + // The time an address is preferred until is needed to properly generate the + // address. + if pl < header.NDPInfiniteLifetime { + state.preferredUntil = now.Add(pl) + } + + if !ndp.generateSLAACAddr(prefix, &state) { + // We were unable to generate an address for the prefix, we do not nothing + // further as there is no reason to maintain state or jobs for a prefix we + // do not have an address for. + return + } + + // Setup the initial jobs to deprecate and invalidate prefix. + + if pl < header.NDPInfiniteLifetime && pl != 0 { + state.deprecationJob.Schedule(pl) + } + + if vl < header.NDPInfiniteLifetime { + state.invalidationJob.Schedule(vl) + state.validUntil = now.Add(vl) + } + + // If the address is assigned (DAD resolved), generate a temporary address. + if state.stableAddr.ref.getKind() == permanent { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) + } + + ndp.slaacPrefixes[prefix] = state +} + +// addSLAACAddr adds a SLAAC address to the NIC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return nil + } + + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { + // Informed by the integrator not to add the address. + return nil + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + if err != nil { + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) + } + + return ref +} + +// generateSLAACAddr generates a SLAAC address for prefix. +// +// Returns true if an address was successfully generated. +// +// Panics if the prefix is not a SLAAC prefix or it already has an address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { + if r := state.stableAddr.ref; r != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if state.generationAttempts == state.maxGenerationAttempts { + return false + } + + var generatedAddr tcpip.AddressWithPrefix + addrBytes := []byte(prefix.ID()) + + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { + return false + } + + dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures + if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + dadCounter, + oIID.SecretKey, + ) + } else if dadCounter == 0 { + // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if + // the DAD counter is non-zero, we cannot use this method. + // + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return false + } + + // Generate an address within prefix from the modified EUI-64 of ndp's + // NIC's Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } else { + // We have no way to regenerate an address in response to an address + // conflict when addresses are not generated with opaque IIDs. + return false + } + + generatedAddr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, + } + + if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + break + } + + state.stableAddr.localGenerationFailures++ + } + + if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { + state.stableAddr.ref = ref + state.generationAttempts++ + return true + } + + return false +} + +// regenerateSLAACAddr regenerates an address for a SLAAC prefix. +// +// If generating a new address for the prefix fails, the prefix will be +// invalidated. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix)) + } + + if ndp.generateSLAACAddr(prefix, &state) { + ndp.slaacPrefixes[prefix] = state + return + } + + // We were unable to generate a permanent address for the SLAAC prefix so + // invalidate the prefix as there is no reason to maintain state for a + // SLAAC prefix we do not have an address for. + ndp.invalidateSLAACPrefix(prefix, state) +} + +// generateTempSLAACAddr generates a new temporary SLAAC address. +// +// If resetGenAttempts is true, the prefix's generation counter will be reset. +// +// Returns true if a new address was generated. +func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { + // Are we configured to auto-generate new temporary global addresses for the + // prefix? + if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() { + return false + } + + if resetGenAttempts { + prefixState.generationAttempts = 0 + prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1 + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if prefixState.generationAttempts == prefixState.maxGenerationAttempts { + return false + } + + stableAddr := prefixState.stableAddr.ref.address() + now := time.Now() + + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. + vl := ndp.configs.MaxTempAddrValidLifetime + if prefixState.validUntil != (time.Time{}) { + if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { + vl = prefixVL + } + } + + if vl <= 0 { + // Cannot create an address without a valid lifetime. + return false + } + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or the + // maximum temporary address preferred lifetime - the temporary address desync + // factor. + pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor + if prefixState.preferredUntil != (time.Time{}) { + if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { + // Respect the preferred lifetime of the prefix, as per RFC 4941 section + // 3.3 step 4. + pl = prefixPL + } + } + + // As per RFC 4941 section 3.3 step 5, a temporary address is created only if + // the calculated preferred lifetime is greater than the advance regeneration + // duration. In particular, we MUST NOT create a temporary address with a zero + // Preferred Lifetime. + if pl <= ndp.configs.RegenAdvanceDuration { + return false + } + + // Attempt to generate a new address that is not already assigned to the NIC. + var generatedAddr tcpip.AddressWithPrefix + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { + return false + } + + generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) + if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + break + } + } + + // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary + // address with a zero preferred lifetime. The checks above ensure this + // so we know the address is not deprecated. + ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) + if ref == nil { + return false + } + + state := tempSLAACAddrState{ + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) + } + + ndp.deprecateSLAACAddress(tempAddrState.ref) + }), + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr)) + } + + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) + }), + regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr)) + } + + // If an address has already been regenerated for this address, don't + // regenerate another address. + if tempAddrState.regenerated { + return + } + + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */) + prefixState.tempAddrs[generatedAddr.Address] = tempAddrState + ndp.slaacPrefixes[prefix] = prefixState + }), + createdAt: now, + ref: ref, + } + + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) + + prefixState.generationAttempts++ + prefixState.tempAddrs[generatedAddr.Address] = state + + return true +} + +// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix)) + } + + ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts) + ndp.slaacPrefixes[prefix] = state +} + +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { + // If the preferred lifetime is zero, then the prefix should be deprecated. + deprecated := pl == 0 + if deprecated { + ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) + } else { + prefixState.stableAddr.ref.deprecated = false + } + + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() + + now := time.Now() + + // Schedule the deprecation job if prefix has a finite preferred lifetime. + if pl < header.NDPInfiniteLifetime { + if !deprecated { + prefixState.deprecationJob.Schedule(pl) + } + prefixState.preferredUntil = now.Add(pl) + } else { + prefixState.preferredUntil = time.Time{} + } + + // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: + // + // 1) If the received Valid Lifetime is greater than 2 hours or greater than + // RemainingLifetime, set the valid lifetime of the prefix to the + // advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the + // advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. + + if vl >= header.NDPInfiniteLifetime { + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. + prefixState.invalidationJob.Cancel() + prefixState.validUntil = time.Time{} + } else { + var effectiveVl time.Duration + var rl time.Duration + + // If the prefix was originally set to be valid forever, assume the + // remaining time to be the maximum possible value. + if prefixState.validUntil == (time.Time{}) { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(prefixState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl > MinPrefixInformationValidLifetimeForUpdate { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if effectiveVl != 0 { + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) + } + } + + // If DAD is not yet complete on the stable address, there is no need to do + // work with temporary addresses. + if prefixState.stableAddr.ref.getKind() != permanent { + return + } + + // Note, we do not need to update the entries in the temporary address map + // after updating the jobs because the jobs are held as pointers. + var regenForAddr tcpip.Address + allAddressesRegenerated := true + for tempAddr, tempAddrState := range prefixState.tempAddrs { + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. Note, the valid lifetime of a + // temporary address is relative to the address's creation time. + validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) + if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { + validUntil = prefixState.validUntil + } + + // If the address is no longer valid, invalidate it immediately. Otherwise, + // reset the invalidation job. + newValidLifetime := validUntil.Sub(now) + if newValidLifetime <= 0 { + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) + continue + } + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or + // the maximum temporary address preferred lifetime - the temporary address + // desync factor. Note, the preferred lifetime of a temporary address is + // relative to the address's creation time. + preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) + if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { + preferredUntil = prefixState.preferredUntil + } + + // If the address is no longer preferred, deprecate it immediately. + // Otherwise, schedule the deprecation job again. + newPreferredLifetime := preferredUntil.Sub(now) + tempAddrState.deprecationJob.Cancel() + if newPreferredLifetime <= 0 { + ndp.deprecateSLAACAddress(tempAddrState.ref) + } else { + tempAddrState.ref.deprecated = false + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) + } + + tempAddrState.regenJob.Cancel() + if tempAddrState.regenerated { + } else { + allAddressesRegenerated = false + + if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration { + // The new preferred lifetime is less than the advance regeneration + // duration so regenerate an address for this temporary address + // immediately after we finish iterating over the temporary addresses. + regenForAddr = tempAddr + } else { + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + } + } + } + + // Generate a new temporary address if all of the existing temporary addresses + // have been regenerated, or we need to immediately regenerate an address + // due to an update in preferred lifetime. + // + // If each temporay address has already been regenerated, no new temporary + // address will be generated. To ensure continuation of temporary SLAAC + // addresses, we manually try to regenerate an address here. + if len(regenForAddr) != 0 || allAddressesRegenerated { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok { + state.regenerated = true + prefixState.tempAddrs[regenForAddr] = state + } + } +} + +// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP +// dispatcher that ref has been deprecated. +// +// deprecateSLAACAddress does nothing if ref is already deprecated. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { + if ref.deprecated { + return + } + + ref.deprecated = true + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) + } +} + +// invalidateSLAACPrefix invalidates a SLAAC prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { + if r := state.stableAddr.ref; r != nil { + // Since we are already invalidating the prefix, do not invalidate the + // prefix when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) + } + } + + ndp.cleanupSLAACPrefixResources(prefix, state) +} + +// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's +// resources. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + } + + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() { + return + } + + if !invalidatePrefix { + // If the prefix is not being invalidated, disassociate the address from the + // prefix and do nothing further. + state.stableAddr.ref = nil + ndp.slaacPrefixes[prefix] = state + return + } + + ndp.cleanupSLAACPrefixResources(prefix, state) +} + +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. +// +// Panics if the SLAAC prefix is not known. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + // Invalidate all temporary addresses. + for tempAddr, tempAddrState := range state.tempAddrs { + ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) + } + + state.stableAddr.ref = nil + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() + delete(ndp.slaacPrefixes, prefix) +} + +// invalidateTempSLAACAddr invalidates a temporary SLAAC address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + // Since we are already invalidating the address, do not invalidate the + // address when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) + } + + ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary +// SLAAC address's resources from ndp. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + } + + if !invalidateAddr { + return + } + + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr)) + } + + tempAddrState, ok := state.tempAddrs[addr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr)) + } + + ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's +// jobs and entry. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() + delete(tempAddrs, tempAddr) +} + +// cleanupState cleans up ndp's state. +// +// If hostOnly is true, then only host-specific state will be cleaned up. +// +// cleanupState MUST be called with hostOnly set to true when ndp's NIC is +// transitioning from a host to a router. This function will invalidate all +// discovered on-link prefixes, discovered routers, and auto-generated +// addresses. +// +// If hostOnly is true, then the link-local auto-generated address will not be +// invalidated as routers are also expected to generate a link-local address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupState(hostOnly bool) { + linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() + linkLocalPrefixes := 0 + for prefix, state := range ndp.slaacPrefixes { + // RFC 4862 section 5 states that routers are also expected to generate a + // link-local address so we do not invalidate them if we are cleaning up + // host-only state. + if hostOnly && prefix == linkLocalSubnet { + linkLocalPrefixes++ + continue + } + + ndp.invalidateSLAACPrefix(prefix, state) + } + + if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { + panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) + } + + for prefix := range ndp.onLinkPrefixes { + ndp.invalidateOnLinkPrefix(prefix) + } + + if got := len(ndp.onLinkPrefixes); got != 0 { + panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) + } + + for router := range ndp.defaultRouters { + ndp.invalidateDefaultRouter(router) + } + + if got := len(ndp.defaultRouters); got != 0 { + panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) + } + + ndp.dhcpv6Configuration = 0 +} + +// startSolicitingRouters starts soliciting routers, as per RFC 4861 section +// 6.3.7. If routers are already being solicited, this function does nothing. +// +// The NIC ndp belongs to MUST be locked. +func (ndp *ndpState) startSolicitingRouters() { + if ndp.rtrSolicit.timer != nil { + // We are already soliciting routers. + return + } + + remaining := ndp.configs.MaxRtrSolicitations + if remaining == 0 { + return + } + + // Calculate the random delay before sending our first RS, as per RFC + // 4861 section 6.3.7. + var delay time.Duration + if ndp.configs.MaxRtrSolicitationDelay > 0 { + delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + } + + var done bool + ndp.rtrSolicit.done = &done + ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that the RS timer fired after another + // goroutine already obtained the NIC lock and stopped solicitations. + // Simply return here and do nothing further. + ndp.nic.mu.Unlock() + return + } + + // As per RFC 4861 section 4.1, the source of the RS is an address assigned + // to the sending interface, or the unspecified address if no address is + // assigned to the sending interface. + ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) + if ref == nil { + ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + } + ndp.nic.mu.Unlock() + + localAddr := ref.address() + r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + defer r.Release() + + // Route should resolve immediately since + // header.IPv6AllRoutersMulticastAddress is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { + return + } + + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) + } else if c != nil { + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) + } + + // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source + // link-layer address option if the source address of the NDP RS is + // specified. This option MUST NOT be included if the source address is + // unspecified. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + var optsSerializer header.NDPOptionsSerializer + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + optsSerializer = header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + } + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.NDPPayload()) + rs.Options().Serialize(optsSerializer) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) + + sent := r.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil, + NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + TOS: DefaultTOS, + }, pkt, + ); err != nil { + sent.Dropped.Increment() + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) + // Don't send any more messages if we had an error. + remaining = 0 + } else { + sent.RouterSolicit.Increment() + remaining-- + } + + ndp.nic.mu.Lock() + if done || remaining == 0 { + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil + } else if ndp.rtrSolicit.timer != nil { + // Note, we need to explicitly check to make sure that + // the timer field is not nil because if it was nil but + // we still reached this point, then we know the NIC + // was requested to stop soliciting routers so we don't + // need to send the next Router Solicitation message. + ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) + } + ndp.nic.mu.Unlock() + }) + +} + +// stopSolicitingRouters stops soliciting routers. If routers are not currently +// being solicited, this function does nothing. +// +// The NIC ndp belongs to MUST be locked. +func (ndp *ndpState) stopSolicitingRouters() { + if ndp.rtrSolicit.timer == nil { + // Nothing to do. + return + } + + *ndp.rtrSolicit.done = true + ndp.rtrSolicit.timer.Stop() + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil +} + +// initializeTempAddrState initializes state related to temporary SLAAC +// addresses. +func (ndp *ndpState) initializeTempAddrState() { + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) + + if MaxDesyncFactor != 0 { + ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 525a25218..1a6724c31 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,9 +15,14 @@ package stack_test import ( + "context" + "encoding/binary" + "fmt" "testing" "time" + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -26,74 +31,330 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" - linkAddr1 = "\x02\x02\x03\x04\x05\x06" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 10 * time.Second + + // Extra time to use when waiting for an async event to not occur. + // + // Since a negative check is used to make sure an event did not happen, it is + // okay to use a smaller timeout compared to the positive case since execution + // stall in regards to the monotonic clock will not affect the expected + // outcome. + defaultAsyncNegativeEventTimeout = time.Second ) -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, +var ( + llAddr1 = header.LinkLocalAddr(linkAddr1) + llAddr2 = header.LinkLocalAddr(linkAddr2) + llAddr3 = header.LinkLocalAddr(linkAddr3) + llAddr4 = header.LinkLocalAddr(linkAddr4) + dstAddr = tcpip.FullAddress{ + Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + Port: 25, } +) - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) +func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return tcpip.AddressWithPrefix{} } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, } +} - // Should get the address immediately since we should not have performed - // DAD on it. - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) +// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent +// tcpip.Subnet, and an address where the lower half of the address is composed +// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. +func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { + prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixBytes), + PrefixLen: 64, } - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } + subnet := prefix.Subnet() + + return prefix, subnet, addrForSubnet(subnet, linkAddr) } // ndpDADEvent is a set of parameters that was passed to // ndpDispatcher.OnDuplicateAddressDetectionStatus. type ndpDADEvent struct { - nicid tcpip.NICID + nicID tcpip.NICID addr tcpip.Address resolved bool err *tcpip.Error } +type ndpRouterEvent struct { + nicID tcpip.NICID + addr tcpip.Address + // true if router was discovered, false if invalidated. + discovered bool +} + +type ndpPrefixEvent struct { + nicID tcpip.NICID + prefix tcpip.Subnet + // true if prefix was discovered, false if invalidated. + discovered bool +} + +type ndpAutoGenAddrEventType int + +const ( + newAddr ndpAutoGenAddrEventType = iota + deprecatedAddr + invalidatedAddr +) + +type ndpAutoGenAddrEvent struct { + nicID tcpip.NICID + addr tcpip.AddressWithPrefix + eventType ndpAutoGenAddrEventType +} + +type ndpRDNSS struct { + addrs []tcpip.Address + lifetime time.Duration +} + +type ndpRDNSSEvent struct { + nicID tcpip.NICID + rdnss ndpRDNSS +} + +type ndpDNSSLEvent struct { + nicID tcpip.NICID + domainNames []string + lifetime time.Duration +} + +type ndpDHCPv6Event struct { + nicID tcpip.NICID + configuration stack.DHCPv6ConfigurationFromNDPRA +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. type ndpDispatcher struct { - dadC chan ndpDADEvent + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent + dnsslC chan ndpDNSSLEvent + routeTable []tcpip.Route + dhcpv6ConfigurationC chan ndpDHCPv6Event } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. -// -// If the DAD event matches what we are expecting, send signal on n.dadC. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { - n.dadC <- ndpDADEvent{ - nicid, - addr, - resolved, - err, +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { + if n.dadC != nil { + n.dadC <- ndpDADEvent{ + nicID, + addr, + resolved, + err, + } + } +} + +// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + true, + } + } + + return n.rememberRouter +} + +// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + false, + } + } +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + true, + } + } + + return n.rememberPrefix +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + false, + } + } +} + +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + newAddr, + } + } + return true +} + +func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + deprecatedAddr, + } + } +} + +func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + invalidatedAddr, + } + } +} + +// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { + if c := n.rdnssC; c != nil { + c <- ndpRDNSSEvent{ + nicID, + ndpRDNSS{ + addrs, + lifetime, + }, + } + } +} + +// Implements stack.NDPDispatcher.OnDNSSearchListOption. +func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { + if n.dnsslC != nil { + n.dnsslC <- ndpDNSSLEvent{ + nicID, + domainNames, + lifetime, + } + } +} + +// Implements stack.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { + if c := n.dhcpv6ConfigurationC; c != nil { + c <- ndpDHCPv6Event{ + nicID, + configuration, + } + } +} + +// channelLinkWithHeaderLength is a channel.Endpoint with a configurable +// header length. +type channelLinkWithHeaderLength struct { + *channel.Endpoint + headerLength uint16 +} + +func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { + return l.headerLength +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// resolved flag set to resolved with the specified err. +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { + return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) +} + +// TestDADDisabled tests that an address successfully resolves immediately +// when DAD is not enabled (the default for an empty stack.Options). +func TestDADDisabled(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } + + // Should get the address immediately since we should not have performed + // DAD on it. + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + } + + // We should not have sent any NDP NS messages. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + t.Fatalf("got NeighborSolicit = %d, want = 0", got) } } @@ -101,23 +362,54 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, add // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. +// This tests also validates the NDP NS packet that is transmitted. func TestDADResolve(t *testing.T) { + const nicID = 1 + tests := []struct { name string + linkHeaderLen uint16 dupAddrDetectTransmits uint8 retransTimer time.Duration expectedRetransmitTimer time.Duration }{ - {"1:1s:1s", 1, time.Second, time.Second}, - {"2:1s:1s", 2, time.Second, time.Second}, - {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, + { + name: "1:1s:1s", + dupAddrDetectTransmits: 1, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "2:1s:1s", + linkHeaderLen: 1, + dupAddrDetectTransmits: 2, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "1:2s:2s", + linkHeaderLen: 2, + dupAddrDetectTransmits: 1, + retransTimer: 2 * time.Second, + expectedRetransmitTimer: 2 * time.Second, + }, // 0s is an invalid RetransmitTimer timer and will be fixed to // the default RetransmitTimer value of 1s. - {"1:0s:1s", 1, 0, time.Second}, + { + name: "1:0s:1s", + linkHeaderLen: 3, + dupAddrDetectTransmits: 1, + retransTimer: 0, + expectedRetransmitTimer: time.Second, + }, } for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -128,101 +420,142 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), + headerLength: test.linkHeaderLen, } - - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(opts) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit + // We add a default route so the call to FindRoute below will succeed + // once we have an assigned address. + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: addr3, + NIC: nicID, + }}) - // Should have sent an NDP NS immediately. - if got := stat.Value(); got != 1 { - t.Fatalf("got NeighborSolicit = %d, want = 1", got) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } + // Address should not be considered bound to the NIC yet (DAD ongoing). + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + // Make sure the address does not resolve before the resolution time has + // passed. + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + // Should not get a route even if we specify the local address as the + // tentative address. + { + r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) + if err != tcpip.ErrNoRoute { + t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + } + r.Release() } - - // Wait for the remaining time - some delta (500ms), to - // make sure the address is still not resolved. - const delta = 500 * time.Millisecond - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + { + r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) + if err != tcpip.ErrNoRoute { + t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + } + r.Release() } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + + if t.Failed() { + t.FailNow() } // Wait for DAD to resolve. select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) + if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + } + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if addr.Address != addr1 { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + } + // Should get a route using the address now that it is resolved. + { + r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) + if err != nil { + t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err) + } else if r.LocalAddress != addr1 { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) } + r.Release() } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + { + r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) + if err != nil { + t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err) + } else if r.LocalAddress != addr1 { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) + } + r.Release() } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + + if t.Failed() { + t.FailNow() } // Should not have sent any more NS messages. - if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p := <-e.C + p, _ := e.ReadContext(context.Background()) // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - // Check NDP packet. - checker.IPv6(t, p.Header.ToVectorisedView().First(), + // Make sure the right remote link address is used. + snmc := header.SolicitedNodeAddr(addr1) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + // Check NDP NS packet. + // + // As per RFC 4861 section 4.3, a possible option is the Source Link + // Layer option, but this option MUST NOT be included when the source + // address of the packet is the unspecified address. + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), checker.NDPNS( - checker.NDPNSTargetAddress(addr1))) + checker.NDPNSTargetAddress(addr1), + checker.NDPNSOptions(nil), + )) + + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } } }) } - } // TestDADFail tests to make sure that the DAD process fails if another node is @@ -230,6 +563,8 @@ func TestDADResolve(t *testing.T) { // a node doing DAD for the same address), or if another node is detected to own // the address already (receive an NA message for the tentative address). func TestDADFail(t *testing.T) { + const nicID = 1 + tests := []struct { name string makeBuf func(tgt tcpip.Address) buffer.Prependable @@ -265,13 +600,17 @@ func TestDADFail(t *testing.T) { { "RxAdvert", func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) + pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(pkt.NDPPayload()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -295,7 +634,7 @@ func TestDADFail(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ @@ -305,30 +644,33 @@ func TestDADFail(t *testing.T) { } opts.NDPConfigs.RetransmitTimer = time.Second * 2 - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(header.IPv6ProtocolNumber, pkt) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -344,102 +686,132 @@ func TestDADFail(t *testing.T) { // something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + + // Attempting to add the address again should not fail if the address's + // state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } }) } } -// TestDADStop tests to make sure that the DAD process stops when an address is -// removed. func TestDADStop(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - ndpConfigs := stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } + const nicID = 1 - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + tests := []struct { + name string + stopFn func(t *testing.T, s *stack.Stack) + skipFinalAddrCheck bool + }{ + // Tests to make sure that DAD stops when an address is removed. + { + name: "Remove address", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) + } + }, + }, - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } + // Tests to make sure that DAD stops when the NIC is disabled. + { + name: "Disable NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + }, + }, - // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + // Tests to make sure that DAD stops when the NIC is removed. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("RemoveNIC(%d): %s", nicID, err) + } + }, + // The NIC is removed so we can't check its addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, } - // Remove the address. This should stop DAD. - if err := s.RemoveAddress(1, addr1); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + } - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") - } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Fatalf("got NeighborSolicit = %d, want <= 1", got) + // Address should not be considered bound to the NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + + test.stopFn(t, s) + + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + + if !test.skipFinalAddrCheck { + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + } + + // Should not have sent more than 1 NS message. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + t.Errorf("got NeighborSolicit = %d, want <= 1", got) + } + }) } } @@ -460,6 +832,10 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { // configurations without affecting the default NDP configurations or other // interfaces' configurations. func TestSetNDPConfigurations(t *testing.T) { + const nicID1 = 1 + const nicID2 = 2 + const nicID3 = 3 + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -483,25 +859,36 @@ func TestSetNDPConfigurations(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPDisp: &ndpDisp, }) + expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatalf("expected DAD event for %s", addr) + } + } + // This NIC(1)'s NDP configurations will be updated to // be different from the default. - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) } // Created before updating NIC(1)'s NDP configurations // but updating NIC(1)'s NDP configurations should not // affect other existing NICs. - if err := s.CreateNIC(2, e); err != nil { - t.Fatalf("CreateNIC(2) = %s", err) + if err := s.CreateNIC(nicID2, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) } // Update the NDP configurations on NIC(1) to use DAD. @@ -509,36 +896,38 @@ func TestSetNDPConfigurations(t *testing.T) { DupAddrDetectTransmits: test.dupAddrDetectTransmits, RetransmitTimer: test.retransmitTimer, } - if err := s.SetNDPConfigurations(1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) + if err := s.SetNDPConfigurations(nicID1, configs); err != nil { + t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) } // Created after updating NIC(1)'s NDP configurations // but the stack's default NDP configurations should not // have been updated. - if err := s.CreateNIC(3, e); err != nil { - t.Fatalf("CreateNIC(3) = %s", err) + if err := s.CreateNIC(nicID3, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err) } // Add addresses for each NIC. - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) } - if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) + if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) } - if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) + expectDADEvent(nicID2, addr2) + if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) } + expectDADEvent(nicID3, addr3) // Address should not be considered bound to NIC(1) yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) } // Should get the address on NIC(2) and NIC(3) @@ -546,31 +935,31 @@ func TestSetNDPConfigurations(t *testing.T) { // it as the stack was configured to not do DAD by // default and we only updated the NDP configurations on // NIC(1). - addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) } if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) } - addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) } if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) } // Sleep until right (500ms before) before resolution to // make sure the address didn't resolve on NIC(1) yet. const delta = 500 * time.Millisecond time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) } // Wait for DAD to resolve. @@ -584,25 +973,4386 @@ func TestSetNDPConfigurations(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) + if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + } + addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) + } + }) + } +} + +// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options +// and DHCPv6 configurations specified. +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(0) + raPayload := pkt.NDPPayload() + ra := header.NDPRouterAdvert(raPayload) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(raPayload[2:], rl) + // Populate the Managed Address flag field. + if managedAddress { + // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) + // of the RA payload. + raPayload[1] |= (1 << 7) + } + // Populate the Other Configurations flag field. + if otherConfigurations { + // The Other Configurations flag field is the 6th bit of byte #1 + // (0-indexing) of the RA payload. + raPayload[1] |= (1 << 6) + } + opts := ra.Options() + opts.Serialize(optSer) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + iph.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) +} + +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) +} + +// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related +// fields set. +// +// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the +// DHCPv6 related ones. +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +} + +// raBuf returns a valid NDP Router Advertisement. +// +// Note, raBuf does not populate any of the RA fields other than the +// Router Lifetime. +func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) +} + +// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix +// Information option. +// +// Note, raBufWithPI does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { + flags := uint8(0) + if onLink { + // The OnLink flag is the 7th bit in the flags byte. + flags |= 1 << 7 + } + if auto { + // The Address Auto-Configuration flag is the 6th bit in the + // flags byte. + flags |= 1 << 6 + } + + // A valid header.NDPPrefixInformation must be 30 bytes. + buf := [30]byte{} + // The first byte in a header.NDPPrefixInformation is the Prefix Length + // field. + buf[0] = uint8(prefix.PrefixLen) + // The 2nd byte within a header.NDPPrefixInformation is the Flags field. + buf[1] = flags + // The Valid Lifetime field starts after the 2nd byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[2:], vl) + // The Preferred Lifetime field starts after the 6th byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[6:], pl) + // The Prefix Address field starts after the 14th byte within a + // header.NDPPrefixInformation. + copy(buf[14:], prefix.Address) + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ + header.NDPPrefixInformation(buf[:]), + }) +} + +// TestNoRouterDiscovery tests that router discovery will not be performed if +// configured not to. +func TestNoRouterDiscovery(t *testing.T) { + // Being configured to discover routers means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // router discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// discovered flag set to discovered. +func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { + return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + +// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered router when the dispatcher asks it not to. +func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA for a router we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr2, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the router in the first place. + select { + case <-ndpDisp.routerC: + t.Fatal("should not have received any router events") + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + } +} + +func TestRouterDiscovery(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + } + + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } + + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from another router (lladdr3) with non-zero lifetime. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) + + // Rx an RA from lladdr2 with lesser lifetime. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } + + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) +} + +// TestRouterDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +func TestRouterDiscoveryMaxRouters(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA from 2 more than the max number of discovered routers. + for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + linkAddr := []byte{2, 2, 3, 4, 5, 0} + linkAddr[5] = byte(i) + llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) + + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + + if i <= stack.MaxDiscoveredDefaultRouters { + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + default: + t.Fatal("expected router discovery event") + } + + } else { + select { + case <-ndpDisp.routerC: + t.Fatal("should not have discovered a new router after we already discovered the max number of routers") + default: + } + } + } +} + +// TestNoPrefixDiscovery tests that prefix discovery will not be performed if +// configured not to. +func TestNoPrefixDiscovery(t *testing.T) { + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + + // Being configured to discover prefixes means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // prefix discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) + + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for prefix on nic with ID 1, and the +// discovered flag set to discovered. +func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { + return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + +// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered on-link prefix when the dispatcher asks it not to. +func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { + prefix, subnet, _ := prefixSubnetAddr(0, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix that we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the prefix in the first place. + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have received any prefix events") + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + } +} + +func TestPrefixDiscovery(t *testing.T) { + prefix1, subnet1, _ := prefixSubnetAddr(0, "") + prefix2, subnet2, _ := prefixSubnetAddr(1, "") + prefix3, subnet3, _ := prefixSubnetAddr(2, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) + + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) + + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) + + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: + } + + // Wait for prefix2's most recent invalidation job plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) +} + +func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { + // Update the infinite lifetime value to a smaller value so we can test + // that when we receive a PI with such a lifetime value, we do not + // invalidate the prefix. + const testInfiniteLifetimeSeconds = 2 + const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second + saved := header.NDPInfiniteLifetime + header.NDPInfiniteLifetime = testInfiniteLifetime + defer func() { + header.NDPInfiniteLifetime = saved + }() + + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + subnet := prefix.Subnet() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix in an NDP Prefix Information option (PI) + // with infinite valid lifetime which should not get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + expectPrefixEvent(subnet, true) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + } + + // Receive an RA with finite lifetime. + // The prefix should get invalidated after 1s. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(testInfiniteLifetime): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive an RA with finite lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + expectPrefixEvent(subnet, true) + + // Receive an RA with prefix with an infinite lifetime. + // The prefix should not be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + } + + // Receive an RA with a prefix with a lifetime value greater than the + // set infinite lifetime value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): + } + + // Receive an RA with 0 lifetime. + // The prefix should get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) + expectPrefixEvent(subnet, false) +} + +// TestPrefixDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) + prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + + // Receive an RA with 2 more than the max number of discovered on-link + // prefixes. + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} + prefixAddr[7] = byte(i) + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixAddr[:]), + PrefixLen: 64, + } + prefixes[i] = prefix.Subnet() + buf := [30]byte{} + buf[0] = uint8(prefix.PrefixLen) + buf[1] = 128 + binary.BigEndian.PutUint32(buf[2:], 10) + copy(buf[14:], prefix.Address) + + optSer[i] = header.NDPPrefixInformation(buf[:]) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < stack.MaxDiscoveredOnLinkPrefixes { + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } else { + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") + default: + } + } + } +} + +// Checks to see if list contains an IPv6 address, item. +func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { + protocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: item, + } + + return containsAddr(list, protocolAddress) +} + +// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. +func TestNoAutoGenAddr(t *testing.T) { + prefix, _, _ := prefixSubnetAddr(0, "") + + // Being configured to auto-generate addresses means handle and + // autogen are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, autogen = + // true and forwarding = false (the required configuration to do + // SLAAC) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + autogen := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) + + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// event type is set to eventType. +func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { + return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) +} + +// TestAutoGenAddr tests that an address is properly generated and invalidated +// when configured to do so. +func TestAutoGenAddr(t *testing.T) { + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } +} + +func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { + ret := "" + for _, c := range containList { + if !containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should have %s in the list of addresses\n", c) + } + } + for _, c := range notContainList { + if containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should not have %s in the list of addresses\n", c) + } + } + return ret +} + +// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when +// configured to do so as part of IPv6 Privacy Extensions. +func TestAutoGenTempAddr(t *testing.T) { + const ( + nicID = 1 + newMinVL = 5 + newMinVLDuration = newMinVL * time.Second + ) + + savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + stack.MaxDesyncFactor = time.Nanosecond + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for i, test := range tests { + i := i + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred + // lifetimes. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } + }) +} + +// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not +// generated for auto generated link-local addresses. +func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { + const nicID = 1 + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = time.Nanosecond + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + case <-time.After(defaultAsyncNegativeEventTimeout): + } + }) + } + }) +} + +// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address +// will not be generated until after DAD completes, even if a new Router +// Advertisement is received to refresh lifetimes. +func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { + const ( + nicID = 1 + dadTransmits = 1 + retransmitTimer = 2 * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = 0 + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Receive an RA to trigger SLAAC for prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // DAD on the stable address for prefix has not yet completed. Receiving a new + // RA that would refresh lifetimes should not generate a temporary SLAAC + // address for the prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + default: + } + + // Wait for DAD to complete for the stable address then expect the temporary + // address to be generated. + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } +} + +// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are +// regenerated. +func TestAutoGenTempAddrRegen(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Stop generating temporary addresses + ndpConfigs.AutoGenTempGlobalAddresses = false + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + + // Wait for all the temporary addresses to get invalidated. + tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} + invalidateAfter := newMinVLDuration - 2*regenAfter + for _, addr := range tempAddrs { + // Wait for a deprecation then invalidation event, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation jobs could execute in any + // order. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we shouldn't get a deprecation + // event after. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated event = %+v", e) + case <-time.After(defaultAsyncNegativeEventTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event = %+v", e) + } + case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + + invalidateAfter = regenAfter + } + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + t.Fatal(mismatch) + } +} + +// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's +// regeneration job gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate the prefix. + // + // A new temporary address should be generated after the regeneration + // time has passed since the prefix is deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + expectAutoGenAddrEvent(addr, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + } + + // Prefer the prefix again. + // + // A new temporary address should immediately be generated since the + // regeneration time has already passed since the last address was generated + // - this regeneration does not depend on a job. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr2, newAddr) + + // Increase the maximum lifetimes for temporary addresses to large values + // then refresh the lifetimes of the prefix. + // + // A new address should not be generated after the regeneration time that was + // expected for the previous check. This is because the preferred lifetime for + // the temporary addresses has increased, so it will take more time to + // regenerate a new temporary address. Note, new addresses are only + // regenerated after the preferred lifetime - the regenerate advance duration + // as paased. + ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second + ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + } + + // Set the maximum lifetimes for temporary addresses such that on the next + // RA, the regeneration job gets scheduled again. + // + // The maximum lifetime is the sum of the minimum lifetimes for temporary + // addresses + the time that has already passed since the last address was + // generated so that the regeneration job is needed to generate the next + // address. + newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout + ndpConfigs.MaxTempAddrValidLifetime = newLifetimes + ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) +} + +// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response +// to a mix of DAD conflicts and NIC-local conflicts. +func TestMixedSLAACAddrConflictRegen(t *testing.T) { + const ( + nicID = 1 + nicName = "nic" + lifetimeSeconds = 9999 + // From stack.maxSLAACAddrLocalRegenAttempts + maxSLAACAddrLocalRegenAttempts = 10 + // We use 2 more addreses than the maximum local regeneration attempts + // because we want to also trigger regeneration in response to a DAD + // conflicts for this test. + maxAddrs = maxSLAACAddrLocalRegenAttempts + 2 + dupAddrTransmits = 1 + retransmitTimer = time.Second + ) + + var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte + header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID) + + var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte + header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID) + + prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1) + var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix + var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix + var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix + addrBytes := []byte(subnet.ID()) + for i := 0; i < maxAddrs; i++ { + stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)), + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + // When generating temporary addresses, the resolved stable address for the + // SLAAC prefix will be the first address stable address generated for the + // prefix as we will not simulate address conflicts for the stable addresses + // in tests involving temporary addresses. Address conflicts for stable + // addresses will be done in their own tests. + tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address) + tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address) + } + + tests := []struct { + name string + addrs []tcpip.AddressWithPrefix + tempAddrs bool + initialExpect tcpip.AddressWithPrefix + nicNameFromID func(tcpip.NICID, string) string + }{ + { + name: "Stable addresses with opaque IIDs", + addrs: stableAddrsWithOpaqueIID[:], + nicNameFromID: func(tcpip.NICID, string) string { + return nicName + }, + }, + { + name: "Temporary addresses with opaque IIDs", + addrs: tempAddrsWithOpaqueIID[:], + tempAddrs: true, + initialExpect: stableAddrsWithOpaqueIID[0], + nicNameFromID: func(tcpip.NICID, string) string { + return nicName + }, + }, + { + name: "Temporary addresses with modified EUI64", + addrs: tempAddrsWithModifiedEUI64[:], + tempAddrs: true, + initialExpect: stableAddrWithModifiedEUI64, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: test.tempAddrs, + AutoGenAddressConflictRetries: 1, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: test.nicNameFromID, + }, + }) + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr2, + NIC: nicID, + }}) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + for j := 0; j < len(test.addrs)-1; j++ { + // The NIC will not attempt to generate an address in response to a + // NIC-local conflict after some maximum number of attempts. We skip + // creating a conflict for the address that would be generated as part + // of the last attempt so we can simulate a DAD conflict for this + // address and restart the NIC-local generation process. + if j == maxSLAACAddrLocalRegenAttempts-1 { + continue + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + } + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + // Enable DAD. + ndpDisp.dadC = make(chan ndpDADEvent, 2) + ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits + ndpConfigs.RetransmitTimer = retransmitTimer + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + + // Do SLAAC for prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + if test.initialExpect != (tcpip.AddressWithPrefix{}) { + expectAutoGenAddrEvent(test.initialExpect, newAddr) + expectDADEventAsync(test.initialExpect.Address) + } + + // The last local generation attempt should succeed, but we introduce a + // DAD failure to restart the local generation process. + addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] + expectAutoGenAddrAsyncEvent(addr, newAddr) + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + + // The last address generated should resolve DAD. + addr = test.addrs[len(test.addrs)-1] + expectAutoGenAddrAsyncEvent(addr, newAddr) + expectDADEventAsync(addr.Address) + + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + default: + } + }) + } +} + +// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, +// channel.Endpoint and stack.Stack. +// +// stack.Stack will have a default route through the router (llAddr3) installed +// and a static link-address (linkAddr3) added to the link address cache for the +// router. +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { + t.Helper() + ndpDisp := &ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + return ndpDisp, e, s +} + +// addrForNewConnectionTo returns the local address used when creating a new +// connection to addr. +func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + if err := ep.Connect(addr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", addr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + t.Helper() + + return addrForNewConnectionTo(t, s, dstAddr) +} + +// addrForNewConnectionWithAddr returns the local address used when creating a +// new connection with a specific local address. +func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + if err := ep.Bind(addr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", addr, err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when +// receiving a PI with 0 preferred lifetime. +func TestAutoGenAddrDeprecateFromPI(t *testing.T) { + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) + + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) + + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } + + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } + + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) +} + +// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated +// when its preferred lifetime expires. +func TestAutoGenAddrJobDeprecation(t *testing.T) { + const nicID = 1 + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) + + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } + + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } + + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } + + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultAsyncNegativeEventTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event") + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } +} + +// Tests transitioning a SLAAC address's valid lifetime between finite and +// infinite values. +func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { + const infiniteVLSeconds = 2 + const minVLSeconds = 1 + savedIL := header.NDPInfiniteLifetime + savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + header.NDPInfiniteLifetime = savedIL + }() + stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + infiniteVL uint32 + }{ + { + name: "EqualToInfiniteVL", + infiniteVL: infiniteVLSeconds, + }, + // Our implementation supports changing header.NDPInfiniteLifetime for tests + // such that a packet can be received where the lifetime field has a value + // greater than header.NDPInfiniteLifetime. Because of this, we test to make + // sure that receiving a value greater than header.NDPInfiniteLifetime is + // handled the same as when receiving a value equal to + // header.NDPInfiniteLifetime. + { + name: "MoreThanInfiniteVL", + infiniteVL: infiniteVLSeconds + 1, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + +// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an +// auto-generated address only gets updated when required to, as specified in +// RFC 4862 section 5.5.3.e. +func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { + const infiniteVL = 4294967295 + const newMinVL = 4 + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + ovl uint32 + nvl uint32 + evl uint32 + }{ + // Should update the VL to the minimum VL for updating if the + // new VL is less than newMinVL but was originally greater than + // it. + { + "LargeVLToVLLessThanMinVLForUpdate", + 9999, + 1, + newMinVL, + }, + { + "LargeVLTo0", + 9999, + 0, + newMinVL, + }, + { + "InfiniteVLToVLLessThanMinVLForUpdate", + infiniteVL, + 1, + newMinVL, + }, + { + "InfiniteVLTo0", + infiniteVL, + 0, + newMinVL, + }, + + // Should not update VL if original VL was less than newMinVL + // and the new VL is also less than newMinVL. + { + "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", + newMinVL - 1, + newMinVL - 3, + newMinVL - 1, + }, + + // Should take the new VL if the new VL is greater than the + // remaining time or is greater than newMinVL. + { + "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", + newMinVL + 5, + newMinVL + 3, + newMinVL + 3, + }, + { + "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", + newMinVL - 3, + newMinVL - 1, + newMinVL - 1, + }, + { + "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", + newMinVL - 3, + newMinVL + 1, + newMinVL + 1, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + + // + // Validate that the VL for the address got set + // to test.evl. + // + + // The address should not be invalidated until the effective valid + // lifetime has passed. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): + } + + // Wait for the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + +// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed +// by the user, its resources will be cleaned up and an invalidation event will +// be sent to the integrator. +func TestAutoGenAddrRemoval(t *testing.T) { + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive a PI to auto-generate an address. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr, newAddr) + + // Removing the address should result in an invalidation event + // immediately. + if err := s.RemoveAddress(1, addr.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + + // Wait for the original valid lifetime to make sure the original job got + // cancelled/cleaned up. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + } +} + +// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously +// assigned to the NIC but is in the permanentExpired state. +func TestAutoGenAddrAfterRemoval(t *testing.T) { + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) + + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) + + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) + + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) + + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) +} + +// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that +// is already assigned to the NIC, the static address remains. +func TestAutoGenAddrStaticConflict(t *testing.T) { + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Add the address as a static address before SLAAC tries to add it. + if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive a PI where the generated address will be the same as the one + // that we already have assigned statically. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") + default: + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Should not get an invalidation event after the PI's invalidation + // time. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } +} + +// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use +// opaque interface identifiers when configured to do so. +func TestAutoGenAddrWithOpaqueIID(t *testing.T) { + const nicID = 1 + const nicName = "nic1" + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) + // addr1 and addr2 are the addresses that are expected to be generated when + // stack.Stack is configured to generate opaque interface identifiers as + // defined by RFC 7217. + addrBytes := []byte(subnet1.ID()) + addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), + PrefixLen: 64, + } + addrBytes = []byte(subnet2.ID()) + addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), + PrefixLen: 64, + } + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in a PI. + const validLifetimeSecondPrefix1 = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in a PI with a large valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } +} + +func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxMaxRetries = 3 + const lifetimeSeconds = 10 + + // Needed for the temporary address sub test. + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MaxDesyncFactor = time.Nanosecond + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix { + addrBytes := []byte(subnet.ID()) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)), + PrefixLen: 64, + } + } + + expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + } + + expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + stableAddrForTempAddrTest := addrForSubnet(subnet, 0) + + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + return nil + + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(subnet, dadCounter) + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + autoGenLinkLocal: true, + prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + return nil + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter) + }, + }, + { + name: "Temporary address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + header.InitialTempIID(tempIIDHistory, nil, nicID) + + // Generate a stable SLAAC address so temporary addresses will be + // generated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) + expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) + + // The stable address will be assigned throughout the test. + return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} + }, + addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address) + }, + }, + } + + for _, addrType := range addrTypes { + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the parallel + // tests complete and limit the number of parallel tests running at the same + // time to reduce flakes. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run(addrType.name, func(t *testing.T) { + for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { + for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { + maxRetries := maxRetries + numFailures := numFailures + addrType := addrType + + t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := addrType.ndpConfigs + ndpConfigs.AutoGenAddressConflictRetries = maxRetries + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + var tempIIDHistory [header.IIDSize]byte + stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + + // Simulate DAD conflicts so the address is regenerated. + for i := uint8(0); i < numFailures; i++ { + addr := addrType.addrGenFn(i, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) + expectDADEvent(t, &ndpDisp, addr.Address, false) + + // Attempting to add the address manually should not fail if the + // address's state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + } + if err := s.RemoveAddress(nicID, addr.Address); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) + } + expectDADEvent(t, &ndpDisp, addr.Address, false) + } + + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // If we had less failures than generation attempts, we should have + // an address after DAD resolves. + if maxRetries+1 > numFailures { + addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, &ndpDisp, addr.Address, true) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { + t.Fatal(mismatch) + } + } + + // Should not attempt address generation again. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncNegativeEventTimeout): + } + }) + } + } + }) + } +} + +// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is +// not made for SLAAC addresses generated with an IID based on the NIC's link +// address. +func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { + const nicID = 1 + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxRetries = 3 + const lifetimeSeconds = 10 + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + subnet tcpip.Subnet + triggerSLAACFn func(e *channel.Endpoint) + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + subnet: subnet, + triggerSLAACFn: func(e *channel.Endpoint) { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + AutoGenAddressConflictRetries: maxRetries, + }, + autoGenLinkLocal: true, + subnet: header.IPv6LinkLocalPrefix.Subnet(), + triggerSLAACFn: func(e *channel.Endpoint) {}, + }, + } + + for _, addrType := range addrTypes { + addrType := addrType + + t.Run(addrType.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + addrType.triggerSLAACFn(e) + + addrBytes := []byte(addrType.subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Should not attempt address regeneration. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncNegativeEventTimeout): + } + }) + } +} + +// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address +// generation in response to DAD conflicts does not refresh the lifetimes. +func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = 2 * time.Second + const failureTimer = time.Second + const maxRetries = 1 + const lifetimeSeconds = 5 + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + addrBytes := []byte(subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict after some time has passed. + time.Sleep(failureTimer) + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Let the next address resolve. + addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) + expectAutoGenAddrEvent(addr, newAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // Address should be deprecated/invalidated after the lifetime expires. + // + // Note, the remaining lifetime is calculated from when the PI was first + // processed. Since we wait for some time before simulating a DAD conflict + // and more time for the new address to resolve, the new address is only + // expected to be valid for the remaining time. The DAD conflict should + // not have reset the lifetimes. + // + // We expect either just the invalidation event or the deprecation event + // followed by the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if e.eventType == deprecatedAddr { + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") + } + } else { + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + } + case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for auto gen addr event") + } +} + +// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event +// to the integrator when an RA is received with the NDP Recursive DNS Server +// option with at least one valid address. +func TestNDPRecursiveDNSServerDispatch(t *testing.T) { + tests := []struct { + name string + opt header.NDPRecursiveDNSServer + expected *ndpRDNSS + }{ + { + "Unspecified", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }), + nil, + }, + { + "Multicast", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }), + nil, + }, + { + "OptionTooSmall", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, + }), + nil, + }, + { + "0Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + }), + nil, + }, + { + "Valid1Address", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + }, + 2 * time.Second, + }, + }, + { + "Valid2Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + }, + time.Second, + }, + }, + { + "Valid3Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", + }, + 0, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + // We do not expect more than a single RDNSS + // event at any time for this test. + rdnssC: make(chan ndpRDNSSEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) + + if test.expected != nil { + select { + case e := <-ndpDisp.rdnssC: + if e.nicID != 1 { + t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) + } + if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { + t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) + } + if e.rdnss.lifetime != test.expected.lifetime { + t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) + } + default: + t.Fatal("expected an RDNSS option event") + } + } + + // Should have no more RDNSS options. + select { + case e := <-ndpDisp.rdnssC: + t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) + default: + } + }) + } +} + +// TestNDPDNSSearchListDispatch tests that the integrator is informed when an +// NDP DNS Search List option is received with at least one domain name in the +// search list. +func TestNDPDNSSearchListDispatch(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dnsslC: make(chan ndpDNSSLEvent, 3), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + optSer := header.NDPOptionsSerializer{ + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 0, + 2, 'h', 'i', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 'i', + 0, + 2, 'a', 'm', + 2, 'm', 'e', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 1, 0, + 3, 'x', 'y', 'z', + 0, + 5, 'h', 'e', 'l', 'l', 'o', + 5, 'w', 'o', 'r', 'l', 'd', + 0, + 4, 't', 'h', 'i', 's', + 2, 'i', 's', + 1, 'a', + 4, 't', 'e', 's', 't', + 0, + }), + } + expected := []struct { + domainNames []string + lifetime time.Duration + }{ + { + domainNames: []string{ + "hi", + }, + lifetime: 0, + }, + { + domainNames: []string{ + "i", + "am.me", + }, + lifetime: time.Second, + }, + { + domainNames: []string{ + "xyz", + "hello.world", + "this.is.a.test", + }, + lifetime: 256 * time.Second, + }, + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + + for i, expected := range expected { + select { + case dnssl := <-ndpDisp.dnsslC: + if dnssl.nicID != nicID { + t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID) + } + if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" { + t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff) + } + if dnssl.lifetime != expected.lifetime { + t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime) + } + default: + t.Fatal("expected a DNSSL event") + } + } + + // Should have no more DNSSL options. + select { + case <-ndpDisp.dnsslC: + t.Fatal("unexpectedly got a DNSSL event") + default: + } +} + +// TestCleanupNDPState tests that all discovered routers and prefixes, and +// auto-generated addresses are invalidated when a NIC becomes a router. +func TestCleanupNDPState(t *testing.T) { + const ( + lifetimeSeconds = 5 + maxRouterAndPrefixEvents = 4 + nicID1 = 1 + nicID2 = 2 + ) + + prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) + e2Addr1 := addrForSubnet(subnet1, linkAddr2) + e2Addr2 := addrForSubnet(subnet2, linkAddr2) + llAddrWithPrefix1 := tcpip.AddressWithPrefix{ + Address: llAddr1, + PrefixLen: 64, + } + llAddrWithPrefix2 := tcpip.AddressWithPrefix{ + Address: llAddr2, + PrefixLen: 64, + } + + tests := []struct { + name string + cleanupFn func(t *testing.T, s *stack.Stack) + keepAutoGenLinkLocal bool + maxAutoGenAddrEvents int + skipFinalAddrCheck bool + }{ + // A NIC should still keep its auto-generated link-local address when + // becoming a router. + { + name: "Enable forwarding", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(ipv6.ProtocolNumber, true) + }, + keepAutoGenLinkLocal: true, + maxAutoGenAddrEvents: 4, + }, + + // A NIC should cleanup all NDP state when it is disabled. + { + name: "Disable NIC", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + }, + + // A NIC should cleanup all NDP state when it is removed. + { + name: "Remove NIC", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + + if err := s.RemoveNIC(nicID1); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) + } + if err := s.RemoveNIC(nicID2); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + // The NICs are removed so we can't check their addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } + + return false, ndpRouterEvent{} + } + + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } + + return false, ndpPrefixEvent{} + } + + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } + + return false, ndpAutoGenAddrEvent{} + } + + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + expectAutoGenAddrEvent() + + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + expectAutoGenAddrEvent() + + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and + // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from + // llAddr4) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } + + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } + + test.cleanupFn(t, s) + + // Collect invalidation events after having NDP state cleaned up. + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < test.maxAutoGenAddrEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } + + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr3, discovered: false}: 1, + {nicID: nicID1, addr: llAddr4, discovered: false}: 1, + {nicID: nicID2, addr: llAddr3, discovered: false}: 1, + {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + + if !test.keepAutoGenLinkLocal { + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 + } + + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + if !test.skipFinalAddrCheck { + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + } + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + } + + // Should not get any more events (invalidation timers should have been + // cancelled when the NDP state was cleaned up). + time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: + } + }) + } +} + +// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly +// informed when new information about what configurations are available via +// DHCPv6 is learned. +func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + t.Helper() + select { + case e := <-ndpDisp.dhcpv6ConfigurationC: + if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { + t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DHCPv6 configuration event") + } + } + + expectNoDHCPv6Event := func() { + t.Helper() + select { + case <-ndpDisp.dhcpv6ConfigurationC: + t.Fatal("unexpected DHCPv6 configuration event") + default: + } + } + + // Even if the first RA reports no DHCPv6 configurations are available, the + // dispatcher should get an event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + // Receiving the same update again should not result in an event to the + // dispatcher. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to none. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + // + // Note, when the M flag is set, the O flag is redundant. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + // Even though the DHCPv6 flags are different, the effective configuration is + // the same so we should not receive a new event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() + + // Cycling the NIC should cause the last DHCPv6 configuration to be cleared. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() +} + +// TestRouterSolicitation tests the initial Router Solicitations that are sent +// when a NIC newly becomes enabled. +func TestRouterSolicitation(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + linkHeaderLen uint16 + linkAddr tcpip.LinkAddress + nicAddr tcpip.Address + expectedSrcAddr tcpip.Address + expectedNDPOpts []header.NDPOption + maxRtrSolicit uint8 + rtrSolicitInt time.Duration + effectiveRtrSolicitInt time.Duration + maxRtrSolicitDelay time.Duration + effectiveMaxRtrSolicitDelay time.Duration + }{ + { + name: "Single RS with 2s delay and interval", + expectedSrcAddr: header.IPv6Any, + maxRtrSolicit: 1, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, + maxRtrSolicitDelay: 2 * time.Second, + effectiveMaxRtrSolicitDelay: 2 * time.Second, + }, + { + name: "Single RS with 4s delay and interval", + expectedSrcAddr: header.IPv6Any, + maxRtrSolicit: 1, + rtrSolicitInt: 4 * time.Second, + effectiveRtrSolicitInt: 4 * time.Second, + maxRtrSolicitDelay: 4 * time.Second, + effectiveMaxRtrSolicitDelay: 4 * time.Second, + }, + { + name: "Two RS with delay", + linkHeaderLen: 1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, + maxRtrSolicit: 2, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, + maxRtrSolicitDelay: 500 * time.Millisecond, + effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, + }, + { + name: "Single RS without delay", + linkHeaderLen: 2, + linkAddr: linkAddr1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, + expectedNDPOpts: []header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }, + maxRtrSolicit: 1, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Two RS without delay and invalid zero interval", + linkHeaderLen: 3, + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, + maxRtrSolicit: 2, + rtrSolicitInt: 0, + effectiveRtrSolicitInt: 4 * time.Second, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Three RS without delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, + maxRtrSolicit: 3, + rtrSolicitInt: 500 * time.Millisecond, + effectiveRtrSolicitInt: 500 * time.Millisecond, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Two RS with invalid negative delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, + maxRtrSolicit: 2, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: -3 * time.Second, + effectiveMaxRtrSolicitDelay: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + p, ok := e.ReadContext(ctx) + if !ok { + t.Fatal("timed out waiting for packet") + return + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) + + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet") + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } + + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) + remaining-- + } + + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) + waitForPkt(defaultAsyncPositiveEventTimeout) + } else { + waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) + } + } + + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) + } + + // Make sure the counter got properly + // incremented. + if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } + }) +} + +func TestStopStartSolicitingRouters(t *testing.T) { + const nicID = 1 + const delay = 0 + const interval = 500 * time.Millisecond + const maxRtrSolicitations = 3 + + tests := []struct { + name string + startFn func(t *testing.T, s *stack.Stack) + // first is used to tell stopFn that it is being called for the first time + // after router solicitations were last enabled. + stopFn func(t *testing.T, s *stack.Stack, first bool) + }{ + // Tests that when forwarding is enabled or disabled, router solicitations + // are stopped or started, respectively. + { + name: "Enable and disable forwarding", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(ipv6.ProtocolNumber, false) + }, + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { + t.Helper() + s.SetForwarding(ipv6.ProtocolNumber, true) + }, + }, + + // Tests that when a NIC is enabled or disabled, router solicitations + // are started or stopped, respectively. + { + name: "Enable and disable NIC", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + }, + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { + t.Helper() + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + }, + }, + + // Tests that when a NIC is removed, router solicitations are stopped. We + // cannot start router solications on a removed NIC. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack, first bool) { + t.Helper() + + // Only try to remove the NIC the first time stopFn is called since it's + // impossible to remove an already removed NIC. + if !first { + return + } + + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(maxRtrSolicitations, 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + p, ok := e.ReadContext(ctx) + if !ok { + t.Fatal("timed out waiting for packet") + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Stop soliciting routers. + test.stopFn(t, s, true /* first */) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + // A single RS may have been sent before solicitations were stopped. + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) + defer cancel() + if _, ok = e.ReadContext(ctx); ok { + t.Fatal("should not have sent more than one RS message") + } + } + + // Stopping router solicitations after it has already been stopped should + // do nothing. + test.stopFn(t, s, false /* first */) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") + } + + // If test.startFn is nil, there is no way to restart router solications. + if test.startFn == nil { + return + } + + // Start soliciting routers. + test.startFn(t, s) + waitForPkt(delay + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + } + + // Starting router solicitations after it has already completed should do + // nothing. + test.startFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after finishing router solicitations") } }) } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go new file mode 100644 index 000000000..27e1feec0 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -0,0 +1,333 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const neighborCacheSize = 512 // max entries per interface + +// neighborCache maps IP addresses to link addresses. It uses the Least +// Recently Used (LRU) eviction strategy to implement a bounded cache for +// dynmically acquired entries. It contains the state machine and configuration +// for running Neighbor Unreachability Detection (NUD). +// +// There are two types of entries in the neighbor cache: +// 1. Dynamic entries are discovered automatically by neighbor discovery +// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm +// reachability with the device once the entry's state becomes Stale. +// 2. Static entries are explicitly added by a user and have no expiration. +// Their state is always Static. The amount of static entries stored in the +// cache is unbounded. +// +// neighborCache implements NUDHandler. +type neighborCache struct { + nic *NIC + state *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + cache map[tcpip.Address]*neighborEntry + dynamic struct { + lru neighborEntryList + + // count tracks the amount of dynamic entries in the cache. This is + // needed since static entries do not count towards the LRU cache + // eviction strategy. + count uint16 + } +} + +var _ NUDHandler = (*neighborCache)(nil) + +// getOrCreateEntry retrieves a cache entry associated with addr. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[remoteAddr]; ok { + entry.mu.RLock() + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.lru.PushFront(entry) + } + entry.mu.RUnlock() + return entry + } + + // The entry that needs to be created must be dynamic since all static + // entries are directly added to the cache via addStaticEntry. + entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + if n.dynamic.count == neighborCacheSize { + e := n.dynamic.lru.Back() + e.mu.Lock() + + delete(n.cache, e.neigh.Addr) + n.dynamic.lru.Remove(e) + n.dynamic.count-- + + e.dispatchRemoveEventLocked() + e.setStateLocked(Unknown) + e.notifyWakersLocked() + e.mu.Unlock() + } + n.cache[remoteAddr] = entry + n.dynamic.lru.PushFront(entry) + n.dynamic.count++ + return entry +} + +// entry looks up the neighbor cache for translating address to link address +// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there +// is a LinkAddressResolver registered with the network protocol, the cache +// attempts to resolve the address and returns ErrWouldBlock. If a Waker is +// provided, it will be notified when address resolution is complete (success +// or not). +// +// If address resolution is required, ErrNoLinkAddress and a notification +// channel is returned for the top level caller to block. Channel is closed +// once address resolution is complete (success or not). +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + } + return e, nil, nil + } + + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry.mu.Lock() + defer entry.mu.Unlock() + + switch s := entry.neigh.State; s { + case Reachable, Static: + return entry.neigh, nil, nil + + case Unknown, Incomplete, Stale, Delay, Probe: + entry.addWakerLocked(w) + + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.neigh, nil, tcpip.ErrNoLinkAddress + } + entry.done = make(chan struct{}) + } + + entry.handlePacketQueuedLocked() + return entry.neigh, entry.done, tcpip.ErrWouldBlock + + case Failed: + return entry.neigh, nil, tcpip.ErrNoLinkAddress + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", s)) + } +} + +// removeWaker removes a waker that has been added when link resolution for +// addr was requested. +func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { + n.mu.Lock() + if entry, ok := n.cache[addr]; ok { + delete(entry.wakers, waker) + } + n.mu.Unlock() +} + +// entries returns all entries in the neighbor cache. +func (n *neighborCache) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(n.cache)) + n.mu.RLock() + for _, entry := range n.cache { + entry.mu.RLock() + entries = append(entries, entry.neigh) + entry.mu.RUnlock() + } + n.mu.RUnlock() + return entries +} + +// addStaticEntry adds a static entry to the neighbor cache, mapping an IP +// address to a link address. If a dynamic entry exists in the neighbor cache +// with the same address, it will be replaced with this static entry. If a +// static entry exists with the same address but different link address, it +// will be updated with the new link address. If a static entry exists with the +// same address and link address, nothing will happen. +func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[addr]; ok { + entry.mu.Lock() + if entry.neigh.State != Static { + // Dynamic entry found with the same address. + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } else if entry.neigh.LinkAddr == linkAddr { + // Static entry found with the same address and link address. + entry.mu.Unlock() + return + } else { + // Static entry found with the same address but different link address. + entry.neigh.LinkAddr = linkAddr + entry.dispatchChangeEventLocked(entry.neigh.State) + entry.mu.Unlock() + return + } + + // Notify that resolution has been interrupted, just in case the entry was + // in the Incomplete or Probe state. + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = entry +} + +// removeEntryLocked removes the specified entry from the neighbor cache. +func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + if entry.neigh.State != Failed { + entry.dispatchRemoveEventLocked() + } + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + + delete(n.cache, entry.neigh.Addr) +} + +// removeEntry removes a dynamic or static entry by address from the neighbor +// cache. Returns true if the entry was found and deleted. +func (n *neighborCache) removeEntry(addr tcpip.Address) bool { + n.mu.Lock() + defer n.mu.Unlock() + + entry, ok := n.cache[addr] + if !ok { + return false + } + + entry.mu.Lock() + defer entry.mu.Unlock() + + n.removeEntryLocked(entry) + return true +} + +// clear removes all dynamic and static entries from the neighbor cache. +func (n *neighborCache) clear() { + n.mu.Lock() + defer n.mu.Unlock() + + for _, entry := range n.cache { + entry.mu.Lock() + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + n.dynamic.lru = neighborEntryList{} + n.cache = make(map[tcpip.Address]*neighborEntry) + n.dynamic.count = 0 +} + +// config returns the NUD configuration. +func (n *neighborCache) config() NUDConfigurations { + return n.state.Config() +} + +// setConfig changes the NUD configuration. +// +// If config contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *neighborCache) setConfig(config NUDConfigurations) { + config.resetInvalidFields() + n.state.SetConfig(config) +} + +// HandleProbe implements NUDHandler.HandleProbe by following the logic defined +// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled +// by the caller. +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry.mu.Lock() + entry.handleProbeLocked(remoteLinkAddr) + entry.mu.Unlock() +} + +// HandleConfirmation implements NUDHandler.HandleConfirmation by following the +// logic defined in RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce cryptographically +// generated addresses, as defined in RFC 3972, Cryptographically Generated +// Addresses (CGA). This ensures that the claimed source of an NDP message is +// the owner of the claimed address. +func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleConfirmationLocked(linkAddr, flags) + entry.mu.Unlock() + } + // The confirmation SHOULD be silently discarded if the recipient did not + // initiate any communication with the target. This is indicated if there is + // no matching entry for the remote address. +} + +// HandleUpperLevelConfirmation implements +// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in +// RFC 4861 section 7.3.1. +func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleUpperLevelConfirmationLocked() + entry.mu.Unlock() + } +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go new file mode 100644 index 000000000..b4fa69e3e --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -0,0 +1,1726 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "math/rand" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // entryStoreSize is the default number of entries that will be generated and + // added to the entry store. This number needs to be larger than the size of + // the neighbor cache to give ample opportunity for verifying behavior during + // cache overflows. Four times the size of the neighbor cache allows for + // three complete cache overflows. + entryStoreSize = 4 * neighborCacheSize + + // typicalLatency is the typical latency for an ARP or NDP packet to travel + // to a router and back. + typicalLatency = time.Millisecond + + // testEntryBroadcastAddr is a special address that indicates a packet should + // be sent to all nodes. + testEntryBroadcastAddr = tcpip.Address("broadcast") + + // testEntryLocalAddr is the source address of neighbor probes. + testEntryLocalAddr = tcpip.Address("local_addr") + + // testEntryBroadcastLinkAddr is a special link address sent back to + // multicast neighbor probes. + testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") + + // infiniteDuration indicates that a task will not occur in our lifetime. + infiniteDuration = time.Duration(math.MaxInt64) +) + +// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor +// entries. The UpdatedAt field is ignored due to a lack of a deterministic +// method to predict the time that an event will be dispatched. +func entryDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + } +} + +// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to +// sort slices of entries for cases where ordering must be ignored. +func entryDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { + config.resetInvalidFields() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return &neighborCache{ + nic: &NIC{ + stack: &Stack{ + clock: clock, + nudDisp: nudDisp, + }, + id: 1, + }, + state: NewNUDState(config, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } +} + +// testEntryStore contains a set of IP to NeighborEntry mappings. +type testEntryStore struct { + mu sync.RWMutex + entriesMap map[tcpip.Address]NeighborEntry +} + +func toAddress(i int) tcpip.Address { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint16(i)) + return tcpip.Address(buf.String()) +} + +func toLinkAddress(i int) tcpip.LinkAddress { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint32(i)) + return tcpip.LinkAddress(buf.String()) +} + +// newTestEntryStore returns a testEntryStore pre-populated with entries. +func newTestEntryStore() *testEntryStore { + store := &testEntryStore{ + entriesMap: make(map[tcpip.Address]NeighborEntry), + } + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + linkAddr := toLinkAddress(i) + + store.entriesMap[addr] = NeighborEntry{ + Addr: addr, + LocalAddr: testEntryLocalAddr, + LinkAddr: linkAddr, + } + } + return store +} + +// size returns the number of entries in the store. +func (s *testEntryStore) size() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.entriesMap) +} + +// entry returns the entry at index i. Returns an empty entry and false if i is +// out of bounds. +func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { + return s.entryByAddr(toAddress(i)) +} + +// entryByAddr returns the entry matching addr for situations when the index is +// not available. Returns an empty entry and false if no entries match addr. +func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + entry, ok := s.entriesMap[addr] + return entry, ok +} + +// entries returns all entries in the store. +func (s *testEntryStore) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(s.entriesMap)) + s.mu.RLock() + defer s.mu.RUnlock() + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + if entry, ok := s.entriesMap[addr]; ok { + entries = append(entries, entry) + } + } + return entries +} + +// set modifies the link addresses of an entry. +func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { + addr := toAddress(i) + s.mu.Lock() + defer s.mu.Unlock() + if entry, ok := s.entriesMap[addr]; ok { + entry.LinkAddr = linkAddr + s.entriesMap[addr] = entry + } +} + +// testNeighborResolver implements LinkAddressResolver to emulate sending a +// neighbor probe. +type testNeighborResolver struct { + clock tcpip.Clock + neigh *neighborCache + entries *testEntryStore + delay time.Duration + onLinkAddressRequest func() +} + +var _ LinkAddressResolver = (*testNeighborResolver)(nil) + +func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(addr) + }) + + // Execute post address resolution action, if available. + if f := r.onLinkAddressRequest; f != nil { + f() + } + return nil +} + +// fakeRequest emulates handling a response for a link address request. +func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { + if entry, ok := r.entries.entryByAddr(addr); ok { + r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } +} + +func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == testEntryBroadcastAddr { + return testEntryBroadcastLinkAddr, true + } + return "", false +} + +func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return 0 +} + +type entryEvent struct { + nicID tcpip.NICID + address tcpip.Address + linkAddr tcpip.LinkAddress + state NeighborState +} + +func TestNeighborCacheGetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheSetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + neigh.setConfig(c) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheEntry(t *testing.T) { + c := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveEntry(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } +} + +type testContext struct { + clock *fakeClock + neigh *neighborCache + store *testEntryStore + linkRes *testNeighborResolver + nudDisp *testNUDDispatcher +} + +func newTestContext(c NUDConfigurations) testContext { + nudDisp := &testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + return testContext{ + clock: clock, + neigh: neigh, + store: store, + linkRes: linkRes, + nudDisp: nudDisp, + } +} + +type overflowOptions struct { + startAtEntryIndex int + wantStaticEntries []NeighborEntry +} + +func (c *testContext) overflowCache(opts overflowOptions) error { + // Fill the neighbor cache to capacity to verify the LRU eviction strategy is + // working properly after the entry removal. + for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + // Add a new entry + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + + var wantEvents []testEntryEventInfo + + // When beyond the full capacity, the cache will evict an entry as per the + // LRU eviction strategy. Note that the number of static entries should not + // affect the total number of dynamic entries that can be added. + if i >= neighborCacheSize+opts.startAtEntryIndex { + removedEntry, ok := c.store.entry(i - neighborCacheSize) + if !ok { + return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + } + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }) + } + + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, testEntryEventInfo{ + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }) + + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the most recent entries. The order of entries reported + // by entries() is undeterministic, so entries have to be sorted before + // comparison. + wantUnsortedEntries := opts.wantStaticEntries + for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + return nil +} + +// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy +// respects the dynamic entry count. +func TestNeighborCacheOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when an entry is removed. +func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the entry + c.neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that +// adding a duplicate static entry with the same link address does not dispatch +// any events. +func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that +// adding a duplicate static entry with a different link address dispatches a +// change event. +func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a duplicate entry with a different link address + staticLinkAddr += "duplicate" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when a static entry is +// added then removed. In this case, the dynamic entry count shouldn't have +// been touched. +func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.removeEntry(entry.Addr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU +// cache eviction strategy keeps count of the dynamic entry count when an entry +// is overwritten by a static entry. Static entries should not count towards +// the size of the LRU cache. +func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Override the entry with a static one using the same address + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: staticLinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheNotifiesWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + id, ok := s.Fetch(false /* block */) + if !ok { + t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + if id != wakerID { + t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + + // Remove the waker before the neighbor cache has the opportunity to send a + // notification. + neigh.removeWaker(entry.Addr, &w) + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Errorf("unexpected notification from waker with id %d", id) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if err != nil { + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheClear(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add a dynamic entry. + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a static entry. + neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Clear shoud remove both dynamic and static entries. + neigh.clear() + + // Remove events dispatched from clear() have no deterministic order so they + // need to be sorted beforehand. + wantUnsortedEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction +// strategy keeps count of the dynamic entry count when all entries are +// cleared. +func TestNeighborCacheClearThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Clear the cache. + c.neigh.clear() + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + frequentlyUsedEntry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + + // The following logic is very similar to overflowCache, but + // periodically refreshes the frequently used entry. + + // Fill the neighbor cache to capacity + for i := 0; i < neighborCacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Keep adding more entries + for i := neighborCacheSize; i < store.size(); i++ { + // Periodically refresh the frequently used entry + if i%(neighborCacheSize/2) == 0 { + _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + } + } + + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // An entry should have been removed, as per the LRU eviction strategy + removedEntry, ok := store.entry(i - neighborCacheSize + 1) + if !ok { + t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the frequently used entry and the most recent entries. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + wantUnsortedEntries := []NeighborEntry{ + { + Addr: frequentlyUsedEntry.Addr, + LocalAddr: frequentlyUsedEntry.LocalAddr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, + }, + } + + for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheConcurrent(t *testing.T) { + const concurrentProcesses = 16 + + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + storeEntries := store.entries() + for _, entry := range storeEntries { + var wg sync.WaitGroup + for r := 0; r < concurrentProcesses; r++ { + wg.Add(1) + go func(entry NeighborEntry) { + defer wg.Done() + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + } + }(entry) + } + + // Wait for all gorountines to send a request + wg.Wait() + + // Process all the requests for a single entry concurrently + clock.advance(typicalLatency) + } + + // All goroutines add in the same order and add more values than can fit in + // the cache. Our eviction strategy requires that the last entries are + // present, up to the size of the neighbor cache, and the rest are missing. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + var wantUnsortedEntries []NeighborEntry + for i := store.size() - neighborCacheSize; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Errorf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheReplace(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add an entry + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // Verify the entry exists + e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + if doneCh != nil { + t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + } + if t.Failed() { + t.FailNow() + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + } + + // Notify of a link address change + var updatedLinkAddr tcpip.LinkAddress + { + entry, ok := store.entry(1) + if !ok { + t.Fatalf("store.entry(1) not found") + } + updatedLinkAddr = entry.LinkAddr + } + store.set(0, updatedLinkAddr) + neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + + // Requesting the entry again should start address resolution + { + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(config.DelayFirstProbeTime + typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + } + + // Verify the entry's new link address + { + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + clock.advance(typicalLatency) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want = NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: updatedLinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + } +} + +func TestNeighborCacheResolutionFailed(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + + var requestCount uint32 + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + onLinkAddressRequest: func() { + atomic.AddUint32(&requestCount, 1) + }, + } + + // First, sanity check that resolution is working + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + + // Verify that address resolution for an unknown address returns ErrNoLinkAddress + before := atomic.LoadUint32(&requestCount) + + entry.Addr += "2" + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + maxAttempts := neigh.config().MaxUnicastProbes + if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } +} + +// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes +// probes and not retrieving a confirmation before the duration defined by +// MaxMulticastProbes * RetransmitTimer. +func TestNeighborCacheResolutionTimeout(t *testing.T) { + config := DefaultNUDConfigurations() + config.RetransmitTimer = time.Millisecond // small enough to cause timeout + + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: time.Minute, // large enough to cause timeout + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } +} + +// TestNeighborCacheStaticResolution checks that static link addresses are +// resolved immediately and don't send resolution requests. +func TestNeighborCacheStaticResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + } + want := NeighborEntry{ + Addr: testEntryBroadcastAddr, + LocalAddr: testEntryLocalAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + } +} + +func BenchmarkCacheClear(b *testing.B) { + b.StopTimer() + config := DefaultNUDConfigurations() + clock := &tcpip.StdClock{} + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: 0, + } + + // Clear for every possible size of the cache + for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + // Fill the neighbor cache to capacity. + for i := 0; i < cacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + b.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh != nil { + <-doneCh + } + } + + b.StartTimer() + neigh.clear() + b.StopTimer() + } +} diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go new file mode 100644 index 000000000..0068cacb8 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -0,0 +1,482 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// NeighborEntry describes a neighboring device in the local network. +type NeighborEntry struct { + Addr tcpip.Address + LocalAddr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +// NeighborState defines the state of a NeighborEntry within the Neighbor +// Unreachability Detection state machine, as per RFC 4861 section 7.3.2. +type NeighborState uint8 + +const ( + // Unknown means reachability has not been verified yet. This is the initial + // state of entries that have been created automatically by the Neighbor + // Unreachability Detection state machine. + Unknown NeighborState = iota + // Incomplete means that there is an outstanding request to resolve the + // address. + Incomplete + // Reachable means the path to the neighbor is functioning properly for both + // receive and transmit paths. + Reachable + // Stale means reachability to the neighbor is unknown, but packets are still + // able to be transmitted to the possibly stale link address. + Stale + // Delay means reachability to the neighbor is unknown and pending + // confirmation from an upper-level protocol like TCP, but packets are still + // able to be transmitted to the possibly stale link address. + Delay + // Probe means a reachability confirmation is actively being sought by + // periodically retransmitting reachability probes until a reachability + // confirmation is received, or until the max amount of probes has been sent. + Probe + // Static describes entries that have been explicitly added by the user. They + // do not expire and are not deleted until explicitly removed. + Static + // Failed means traffic should not be sent to this neighbor since attempts of + // reachability have returned inconclusive. + Failed +) + +// neighborEntry implements a neighbor entry's individual node behavior, as per +// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in +// parallel with the sending of packets to a neighbor, necessitating the +// entry's lock to be acquired for all operations. +type neighborEntry struct { + neighborEntryEntry + + nic *NIC + protocol tcpip.NetworkProtocolNumber + + // linkRes provides the functionality to send reachability probes, used in + // Neighbor Unreachability Detection. + linkRes LinkAddressResolver + + // nudState points to the Neighbor Unreachability Detection configuration. + nudState *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + neigh NeighborEntry + + // wakers is a set of waiters for address resolution result. Anytime state + // transitions out of incomplete these waiters are notified. It is nil iff + // address resolution is ongoing and no clients are waiting for the result. + wakers map[*sleep.Waker]struct{} + + // done is used to allow callers to wait on address resolution. It is nil + // iff nudState is not Reachable and address resolution is not yet in + // progress. + done chan struct{} + + isRouter bool + job *tcpip.Job +} + +// newNeighborEntry creates a neighbor cache entry starting at the default +// state, Unknown. Transition out of Unknown by calling either +// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created +// neighborEntry. +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { + return &neighborEntry{ + nic: nic, + linkRes: linkRes, + nudState: nudState, + neigh: NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + State: Unknown, + }, + } +} + +// newStaticNeighborEntry creates a neighbor cache entry starting at the Static +// state. The entry can only transition out of Static by directly calling +// `setStateLocked`. +func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + if nic.stack.nudDisp != nil { + nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + } + return &neighborEntry{ + nic: nic, + nudState: state, + neigh: NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + }, + } +} + +// addWaker adds w to the list of wakers waiting for address resolution. +// Assumes the entry has already been appropriately locked. +func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { + if w == nil { + return + } + if e.wakers == nil { + e.wakers = make(map[*sleep.Waker]struct{}) + } + e.wakers[w] = struct{}{} +} + +// notifyWakersLocked notifies those waiting for address resolution, whether it +// succeeded or failed. Assumes the entry has already been appropriately locked. +func (e *neighborEntry) notifyWakersLocked() { + for w := range e.wakers { + w.Assert() + } + e.wakers = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has +// been added. +func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry +// has changed state or link-layer address. +func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry +// has been removed. +func (e *neighborEntry) dispatchRemoveEventLocked() { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + } +} + +// setStateLocked transitions the entry to the specified state immediately. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +// +// e.mu MUST be locked. +func (e *neighborEntry) setStateLocked(next NeighborState) { + // Cancel the previously scheduled action, if there is one. Entries in + // Unknown, Stale, or Static state do not have scheduled actions. + if timer := e.job; timer != nil { + timer.Cancel() + } + + prev := e.neigh.State + e.neigh.State = next + e.neigh.UpdatedAt = time.Now() + config := e.nudState.Config() + + switch next { + case Incomplete: + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendMulticastProbe() + + case Reachable: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + }) + e.job.Schedule(e.nudState.ReachableTime()) + + case Delay: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Probe) + e.setStateLocked(Probe) + }) + e.job.Schedule(config.DelayFirstProbeTime) + + case Probe: + var retryCounter uint32 + var sendUnicastProbe func() + + sendUnicastProbe = func() { + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendUnicastProbe() + + case Failed: + e.notifyWakersLocked() + e.job = e.nic.stack.newJob(&e.mu, func() { + e.nic.neigh.removeEntryLocked(e) + }) + e.job.Schedule(config.UnreachableTime) + + case Unknown, Stale, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next)) + } +} + +// handlePacketQueuedLocked advances the state machine according to a packet +// being queued for outgoing transmission. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +func (e *neighborEntry) handlePacketQueuedLocked() { + switch e.neigh.State { + case Unknown: + e.dispatchAddEventLocked(Incomplete) + e.setStateLocked(Incomplete) + + case Stale: + e.dispatchChangeEventLocked(Delay) + e.setStateLocked(Delay) + + case Incomplete, Reachable, Delay, Probe, Static, Failed: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or +// Neighbor Solicitation for ARP or NDP, respectively). +// +// Follows the logic defined in RFC 4861 section 7.2.3. +func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { + // Probes MUST be silently discarded if the target address is tentative, does + // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These + // checks MUST be done by the NetworkEndpoint. + + switch e.neigh.State { + case Unknown, Incomplete, Failed: + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchAddEventLocked(Stale) + e.setStateLocked(Stale) + e.notifyWakersLocked() + + case Reachable, Delay, Probe: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + + case Stale: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + } + + case Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleConfirmationLocked processes an incoming neighbor confirmation +// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively). +// +// Follows the state machine defined by RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce Cryptographically +// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the +// claimed source of an NDP message is the owner of the claimed address. +func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + switch e.neigh.State { + case Incomplete: + if len(linkAddr) == 0 { + // "If the link layer has addresses and no Target Link-Layer Address + // option is included, the receiving node SHOULD silently discard the + // received advertisement." - RFC 4861 section 7.2.5 + break + } + + e.neigh.LinkAddr = linkAddr + if flags.Solicited { + e.dispatchChangeEventLocked(Reachable) + e.setStateLocked(Reachable) + } else { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + e.isRouter = flags.IsRouter + e.notifyWakersLocked() + + // "Note that the Override flag is ignored if the entry is in the + // INCOMPLETE state." - RFC 4861 section 7.2.5 + + case Reachable, Stale, Delay, Probe: + sameLinkAddr := e.neigh.LinkAddr == linkAddr + + if !sameLinkAddr { + if !flags.Override { + if e.neigh.State == Reachable { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + break + } + + e.neigh.LinkAddr = linkAddr + + if !flags.Solicited { + if e.neigh.State != Stale { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } else { + // Notify the LinkAddr change, even though NUD state hasn't changed. + e.dispatchChangeEventLocked(e.neigh.State) + } + break + } + } + + if flags.Solicited && (flags.Override || sameLinkAddr) { + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + } + // Set state to Reachable again to refresh timers. + e.setStateLocked(Reachable) + e.notifyWakersLocked() + } + + if e.isRouter && !flags.IsRouter { + // "In those cases where the IsRouter flag changes from TRUE to FALSE as + // a result of this update, the node MUST remove that router from the + // Default Router List and update the Destination Cache entries for all + // destinations using that neighbor as a router as specified in Section + // 7.3.3. This is needed to detect when a node that is used as a router + // stops forwarding packets due to being configured as a host." + // - RFC 4861 section 7.2.5 + e.nic.mu.Lock() + e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr) + e.nic.mu.Unlock() + } + e.isRouter = flags.IsRouter + + case Unknown, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol +// (e.g. TCP acknowledgements) reachability confirmation. +func (e *neighborEntry) handleUpperLevelConfirmationLocked() { + switch e.neigh.State { + case Reachable, Stale, Delay, Probe: + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + // Set state to Reachable again to refresh timers. + } + e.setStateLocked(Reachable) + + case Unknown, Incomplete, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go new file mode 100644 index 000000000..b769fb2fa --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -0,0 +1,2870 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "math" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + + entryTestNICID tcpip.NICID = 1 + entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") + entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") + + // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match the + // MTU of loopback interfaces on Linux systems. + entryTestNetDefaultMTU = 65536 +) + +// eventDiffOpts are the options passed to cmp.Diff to compare entry events. +// The UpdatedAt field is ignored due to a lack of a deterministic method to +// predict the time that an event will be dispatched. +func eventDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + } +} + +// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to +// sort slices of events for cases where ordering must be ignored. +func eventDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +// The following unit tests exercise every state transition and verify its +// behavior with RFC 4681. +// +// | From | To | Cause | Action | Event | +// | ========== | ========== | ========================================== | =============== | ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | +// | Unknown | Stale | Probe w/ unknown address | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | +// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | +// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | +// | Reachable | Stale | Reachable timer expired | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | +// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | +// | Stale | Delay | Packet sent | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | Changed | +// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | Changed | +// | Delay | Probe | Delay timer expired | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | Changed | +// | Probe | Probe | Retransmit timer expired | Send probe | Changed | +// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Failed | | Unreachability timer expired | Delete entry | | + +type testEntryEventType uint8 + +const ( + entryTestAdded testEntryEventType = iota + entryTestChanged + entryTestRemoved +) + +func (t testEntryEventType) String() string { + switch t { + case entryTestAdded: + return "add" + case entryTestChanged: + return "change" + case entryTestRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +// Fields are exported for use with cmp.Diff. +type testEntryEventInfo struct { + EventType testEntryEventType + NICID tcpip.NICID + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +func (e testEntryEventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) +} + +// testNUDDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type testNUDDispatcher struct { + mu sync.Mutex + events []testEntryEventInfo +} + +var _ NUDDispatcher = (*testNUDDispatcher)(nil) + +func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { + d.mu.Lock() + defer d.mu.Unlock() + d.events = append(d.events, e) +} + +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestAdded, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestChanged, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +type entryTestLinkResolver struct { + mu sync.Mutex + probes []entryTestProbeInfo +} + +var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) + +type entryTestProbeInfo struct { + RemoteAddress tcpip.Address + RemoteLinkAddress tcpip.LinkAddress + LocalAddress tcpip.Address +} + +func (p entryTestProbeInfo) String() string { + return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) +} + +// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts +// to the local network if linkAddr is the zero value. +func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + p := entryTestProbeInfo{ + RemoteAddress: addr, + RemoteLinkAddress: linkAddr, + LocalAddress: localAddr, + } + r.mu.Lock() + defer r.mu.Unlock() + r.probes = append(r.probes, p) + return nil +} + +// ResolveStaticAddress attempts to resolve address without sending requests. +// It either resolves the name immediately or returns the empty LinkAddress. +func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + return "", false +} + +// LinkAddressProtocol returns the network protocol of the addresses this +// resolver can resolve. +func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return entryTestNetNumber +} + +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { + clock := newFakeClock() + disp := testNUDDispatcher{} + nic := NIC{ + id: entryTestNICID, + linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + stack: &Stack{ + clock: clock, + nudDisp: &disp, + }, + } + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + nudState := NewNUDState(c, rng) + linkRes := entryTestLinkResolver{} + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + + // Stub out ndpState to verify modification of default routers. + nic.mu.ndp = ndpState{ + nic: &nic, + defaultRouters: make(map[tcpip.Address]defaultRouterState), + } + + // Stub out the neighbor cache to verify deletion from the cache. + nic.neigh = &neighborCache{ + nic: &nic, + state: nudState, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + nic.neigh.cache[entryTestAddr1] = entry + + return entry, &disp, &linkRes, clock +} + +// TestEntryInitiallyUnknown verifies that the state of a newly created +// neighborEntry is Unknown. +func TestEntryInitiallyUnknown(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(time.Hour) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToIncomplete(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + { + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +func TestEntryUnknownToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + updatedAt := e.neigh.UpdatedAt + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should remain the same during address resolution. + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should change after failing address resolution. Timing out after + // sending the last probe transitions the entry to Failed. + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + clock.advance(c.RetransmitTimer) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) + } + e.mu.Unlock() +} + +func TestEntryIncompleteToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +// TestEntryAddsAndClearsWakers verifies that wakers are added when +// addWakerLocked is called and cleared when address resolution finishes. In +// this case, address resolution will finish when transitioning from Incomplete +// to Reachable. +func TestEntryAddsAndClearsWakers(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + w := sleep.Waker{} + s := sleep.Sleeper{} + s.AddWaker(&w, 123) + defer s.Done() + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got := e.wakers; got != nil { + t.Errorf("got e.wakers = %v, want = nil", got) + } + e.addWakerLocked(&w) + if got, want := w.IsAsserted(), false; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + if e.wakers == nil { + t.Error("expected e.wakers to be non-nil") + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.wakers != nil { + t.Errorf("got e.wakers = %v, want = nil", e.wakers) + } + if got, want := w.IsAsserted(), true; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + linkRes.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +type testLocker struct{} + +var _ sync.Locker = (*testLocker)(nil) + +func (*testLocker) Lock() {} +func (*testLocker) Unlock() {} + +func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{ + invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}), + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.isRouter, false; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok { + t.Errorf("unexpected defaultRouter for %s", entryTestAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryReachableToStaleWhenTimeout(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToDelay(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleUpperLevelConfirmationLocked() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToProbe(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario: +// 1. Probe is received +// 2. Entry is created in Stale +// 3. Packet is queued on the entry +// 4. Entry transitions to Delay then Probe +// 5. Probe is sent +func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Probe to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // Probe caused by the Delay-to-Probe transition + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryFailedGetsDeleted(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // Verify the cache contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { + t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + } + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + // Verify the cache no longer contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { + t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) + } +} diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go new file mode 100644 index 000000000..aa7311ec6 --- /dev/null +++ b/pkg/tcpip/stack/neighborstate_string.go @@ -0,0 +1,44 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NeighborState"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Unknown-0] + _ = x[Incomplete-1] + _ = x[Reachable-2] + _ = x[Stale-3] + _ = x[Delay-4] + _ = x[Probe-5] + _ = x[Static-6] + _ = x[Failed-7] +} + +const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed" + +var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53} + +func (i NeighborState) String() string { + if i >= NeighborState(len(_NeighborState_index)-1) { + return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]] +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ab6798aa6..e74d2562a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,48 +15,66 @@ package stack import ( - "strings" - "sync" + "fmt" + "math/rand" + "reflect" + "sort" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) +var ipv4BroadcastAddr = tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 8 * header.IPv4AddressSize, + }, +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint - loopback bool - - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - - stats NICStats - - // ndp is the NDP related state for NIC. - // - // Note, read and write operations on ndp require that the NIC is - // appropriately locked. - ndp ndpState + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + context NICContext + + stats NICStats + neigh *neighborCache + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + + mu struct { + sync.RWMutex + enabled bool + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + mcastJoins map[NetworkEndpointID]uint32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + ndp ndpState + } } // NICStats includes transmitted and received stats. type NICStats struct { Tx DirectionStats Rx DirectionStats + + DisabledRx DirectionStats +} + +func makeNICStats() NICStats { + var s NICStats + tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) + return s } // DirectionStats includes packet and byte counts. @@ -85,59 +103,170 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + nic := &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - loopback: loopback, - primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), - endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), - mcastJoins: make(map[NetworkEndpointID]int32), - packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), - stats: NICStats{ - Tx: DirectionStats{ - Packets: &tcpip.StatCounter{}, - Bytes: &tcpip.StatCounter{}, - }, - Rx: DirectionStats{ - Packets: &tcpip.StatCounter{}, - Bytes: &tcpip.StatCounter{}, - }, - }, - ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - }, - } - nic.ndp.nic = nic + stack: stack, + id: id, + name: name, + linkEP: ep, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + } + nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) + nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) + nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.ndp = ndpState{ + nic: nic, + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), + } + nic.mu.ndp.initializeTempAddrState() // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { - nic.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = []PacketEndpoint{} } for _, netProto := range stack.networkProtocols { - nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + netNum := netProto.Number() + nic.mu.packetEPs[netNum] = nil + nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic, ep, stack) + } + + // Check for Neighbor Unreachability Detection support. + if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 { + rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) + nic.neigh = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } } + nic.linkEP.Attach(nic) + return nic } -// enable enables the NIC. enable will attach the link to its LinkEndpoint and -// join the IPv6 All-Nodes Multicast address (ff02::1). +// enabled returns true if n is enabled. +func (n *NIC) enabled() bool { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + return enabled +} + +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() *tcpip.Error { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if !enabled { + return nil + } + + n.mu.Lock() + err := n.disableLocked() + n.mu.Unlock() + return err +} + +// disableLocked disables n. +// +// It undoes the work done by enable. +// +// n MUST be locked. +func (n *NIC) disableLocked() *tcpip.Error { + if !n.mu.enabled { + return nil + } + + // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be + // invalidated? Currently, Routes will continue to work when a NIC is enabled + // again, and applications may not know that the underlying NIC was ever + // disabled. + + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + n.mu.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupState(false /* hostOnly */) + + // Stop DAD for all the unicast IPv6 endpoints that are in the + // permanentTentative state. + for _, r := range n.mu.endpoints { + if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { + n.mu.ndp.stopDuplicateAddressDetection(addr) + } + } + + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + // The address may have already been removed. + if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + n.mu.enabled = false + return nil +} + +// enable enables n. +// +// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast +// address (ff02::1), start DAD for permanent addresses, and start soliciting +// routers if the stack is not operating as a router. If the stack is also +// configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { - n.attachLinkEndpoint() + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if enabled { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if n.mu.enabled { + return nil + } + + n.mu.enabled = true // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if err := n.AddAddress(tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, - }, NeverPrimaryEndpoint); err != nil { + if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { + return err + } + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { return err } } @@ -159,77 +288,298 @@ func (n *NIC) enable() *tcpip.Error { return nil } - n.mu.Lock() - defer n.mu.Unlock() - + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { return err } - if !n.stack.autoGenIPv6LinkLocal { - return nil + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent + // state. + // + // Addresses may have aleady completed DAD but in the time since the NIC was + // last enabled, other devices may have acquired the same addresses. + for _, r := range n.mu.endpoints { + addr := r.address() + if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { + continue + } + + r.setKind(permanentTentative) + if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { + return err + } } - l2addr := n.linkEP.LinkAddress() + // Do not auto-generate an IPv6 link-local address for loopback devices. + if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { + // The valid and preferred lifetime is infinite for the auto-generated + // link-local address. + n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) + } - // Only attempt to generate the link-local address if we have a - // valid MAC address. + // If we are operating as a router, then do not solicit routers since we + // won't process the RAs anyways. // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil + // Routers do not process Router Advertisements (RA) the same way a host + // does. That is, routers do not learn from RAs (e.g. on-link prefixes + // and default routers). Therefore, soliciting RAs from other routers on + // a link is unnecessary for routers. + if !n.stack.Forwarding(header.IPv6ProtocolNumber) { + n.mu.ndp.startSolicitingRouters() } - addr := header.LinkLocalAddr(l2addr) + return nil +} - _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - }, - }, CanBePrimaryEndpoint) +// remove detaches NIC from the link endpoint, and marks existing referenced +// network endpoints expired. This guarantees no packets between this NIC and +// the network stack. +func (n *NIC) remove() *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + n.disableLocked() + + // TODO(b/151378115): come up with a better way to pick an error than the + // first one. + var err *tcpip.Error + + // Forcefully leave multicast groups. + for nid := range n.mu.mcastJoins { + if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { + err = tempErr + } + } + + // Remove permanent and permanentTentative addresses, so no packet goes out. + for nid, ref := range n.mu.endpoints { + switch ref.getKind() { + case permanentTentative, permanent: + if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { + err = tempErr + } + } + } + + // Release any resources the network endpoint may hold. + for _, ep := range n.networkEndpoints { + ep.Close() + } + + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) return err } -// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it -// to start delivering packets. -func (n *NIC) attachLinkEndpoint() { - n.linkEP.Attach(n) +// becomeIPv6Router transitions n into an IPv6 router. +// +// When transitioning into an IPv6 router, host-only state (NDP discovered +// routers, discovered on-link prefixes, and auto-generated addresses) will +// be cleaned up/invalidated and NDP router solicitations will be stopped. +func (n *NIC) becomeIPv6Router() { + n.mu.Lock() + defer n.mu.Unlock() + + n.mu.ndp.cleanupState(true /* hostOnly */) + n.mu.ndp.stopSolicitingRouters() +} + +// becomeIPv6Host transitions n into an IPv6 host. +// +// When transitioning into an IPv6 host, NDP router solicitations will be +// started. +func (n *NIC) becomeIPv6Host() { + n.mu.Lock() + defer n.mu.Unlock() + + n.mu.ndp.startSolicitingRouters() } // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() - n.promiscuous = enable + n.mu.promiscuous = enable n.mu.Unlock() } func (n *NIC) isPromiscuousMode() bool { n.mu.RLock() - rv := n.promiscuous + rv := n.mu.promiscuous n.mu.RUnlock() return rv } +func (n *NIC) isLoopback() bool { + return n.linkEP.Capabilities()&CapabilityLoopback != 0 +} + // setSpoofing enables or disables address spoofing. func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() - n.spoofing = enable + n.mu.spoofing = enable n.mu.Unlock() } -// primaryEndpoint returns the primary endpoint of n for the given network -// protocol. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { +// primaryEndpoint will return the first non-deprecated endpoint if such an +// endpoint exists for the given protocol and remoteAddr. If no non-deprecated +// endpoint exists, the first deprecated endpoint will be returned. +// +// If an IPv6 primary endpoint is requested, Source Address Selection (as +// defined by RFC 6724 section 5) will be performed. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { + if protocol == header.IPv6ProtocolNumber && remoteAddr != "" { + return n.primaryIPv6Endpoint(remoteAddr) + } + n.mu.RLock() defer n.mu.RUnlock() - for _, r := range n.primary[protocol] { - if r.isValidForOutgoing() && r.tryIncRef() { + var deprecatedEndpoint *referencedNetworkEndpoint + for _, r := range n.mu.primary[protocol] { + if !r.isValidForOutgoingRLocked() { + continue + } + + if !r.deprecated { + if r.tryIncRef() { + // r is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + deprecatedEndpoint.decRefLocked() + deprecatedEndpoint = nil + } + + return r + } + } else if deprecatedEndpoint == nil && r.tryIncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of r in + // case n doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, r's reference count + // will be decremented when such an endpoint is found. + deprecatedEndpoint = r + } + } + + // n doesn't have any valid non-deprecated endpoints, so return + // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated + // endpoints either). + return deprecatedEndpoint +} + +// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC +// 6724 section 5). +type ipv6AddrCandidate struct { + ref *referencedNetworkEndpoint + scope header.IPv6AddressScope +} + +// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 and 7 are followed. +// +// remoteAddr must be a valid IPv6 address. +func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { + n.mu.RLock() + ref := n.primaryIPv6EndpointRLocked(remoteAddr) + n.mu.RUnlock() + return ref +} + +// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 and 7 are followed. +// +// remoteAddr must be a valid IPv6 address. +// +// n.mu MUST be read locked. +func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { + primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] + + if len(primaryAddrs) == 0 { + return nil + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) + for _, r := range primaryAddrs { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !r.isValidForOutgoingRLocked() { + continue + } + + addr := r.address() + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) + } + + cs = append(cs, ipv6AddrCandidate{ + ref: r, + scope: scope, + }) + } + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.ref.address() == remoteAddr { + return true + } + if sb.ref.address() == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { + return saTemp + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if r := c.ref; r.tryIncRef() { return r } } @@ -237,62 +587,87 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +// hasPermanentAddrLocked returns true if n has a permanent (including currently +// tentative) address, addr. +func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] + + if !ok { + return false + } + + kind := ref.getKind() + + return kind == permanent || kind == permanentTentative +} + +type getRefBehaviour int + +const ( + // spoofing indicates that the NIC's spoofing flag should be observed when + // getting a NIC's referenced network endpoint. + spoofing getRefBehaviour = iota + + // promiscuous indicates that the NIC's promiscuous flag should be observed + // when getting a NIC's referenced network endpoint. + promiscuous +) + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) + return n.getRefOrCreateTemp(protocol, address, peb, spoofing) } // getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. If none exists a temporary one may be created if -// we are in promiscuous mode or spoofing. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { - id := NetworkEndpointID{address} - +// protocol and address. +// +// If none exists a temporary one may be created if we are in promiscuous mode +// or spoofing. Promiscuous mode will only be checked if promiscuous is true. +// Similarly, spoofing will only be checked if spoofing is true. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { n.mu.RLock() - if ref, ok := n.endpoints[id]; ok { + var spoofingOrPromiscuous bool + switch tempRef { + case spoofing: + spoofingOrPromiscuous = n.mu.spoofing + case promiscuous: + spoofingOrPromiscuous = n.mu.promiscuous + } + + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // An endpoint with this id exists, check if it can be used and return it. - switch ref.getKind() { - case permanentExpired: - if !spoofingOrPromiscuous { - n.mu.RUnlock() - return nil - } - fallthrough - case temporary, permanent: - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } + if !ref.isAssignedRLocked(spoofingOrPromiscuous) { + n.mu.RUnlock() + return nil + } + + if ref.tryIncRef() { + n.mu.RUnlock() + return ref } } - // A usable reference was not found, create a temporary one if requested by - // the caller or if the address is found in the NIC's subnets. - createTempEP := spoofingOrPromiscuous - if !createTempEP { - for _, sn := range n.addressRanges { - // Skip the subnet address. - if address == sn.ID() { - continue - } - // For now just skip the broadcast address, until we support it. - // FIXME(b/137608825): Add support for sending/receiving directed - // (subnet) broadcast. - if address == sn.Broadcast() { - continue - } - if sn.Contains(address) { - createTempEP = true - break - } + // Check if address is a broadcast address for the endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + n.mu.RUnlock() + return ref } } + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous n.mu.RUnlock() if !createTempEP { @@ -303,11 +678,44 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok { + ref := n.getRefOrCreateTempLocked(protocol, address, peb) + n.mu.Unlock() + return ref +} + +// getRefForBroadcastLocked returns an endpoint where address is the IPv4 +// broadcast address for the endpoint's network. +// +// n.mu MUST be read locked. +func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { + for _, ref := range n.mu.endpoints { + // Only IPv4 has a notion of broadcast addresses. + if ref.protocol != header.IPv4ProtocolNumber { + continue + } + + addr := ref.addrWithPrefix() + subnet := addr.Subnet() + if subnet.IsBroadcast(address) && ref.tryIncRef() { + return ref + } + } + + return nil +} + +/// getRefOrCreateTempLocked returns an existing endpoint for address or creates +/// and returns a temporary endpoint. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +// +// n.mu must be write locked. +func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { - n.mu.Unlock() return ref } // tryIncRef failing means the endpoint is scheduled to be removed once the @@ -316,10 +724,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t n.removeEndpointLocked(ref) } + // Check if address is a broadcast address for an endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + return ref + } + } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - n.mu.Unlock() return nil } ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ @@ -328,26 +744,40 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary) - - n.mu.Unlock() + }, peb, temporary, static, false) return ref } -func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { - id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.endpoints[id]; ok { +// addAddressLocked adds a new protocolAddress to n. +// +// If n already has the address in a non-permanent state, and the kind given is +// permanent, that address will be promoted in place and its properties set to +// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress. +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP addresses before adding them. + + // Sanity check. + id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} + if ref, ok := n.mu.endpoints[id]; ok { + // Endpoint already exists. + if kind != permanent { + return nil, tcpip.ErrDuplicateAddress + } switch ref.getKind() { case permanentTentative, permanent: // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect - // the new peb. + // Promote the endpoint to become permanent and respect the new peb, + // configType and deprecated status. if ref.tryIncRef() { + // TODO(b/147748385): Perform Duplicate Address Detection when promoting + // an IPv6 endpoint to permanent. ref.setKind(permanent) + ref.deprecated = deprecated + ref.configType = configType - refs := n.primary[ref.protocol] + refs := n.mu.primary[ref.protocol] for i, r := range refs { if r == ref { switch peb { @@ -357,9 +787,9 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p if i == 0 { return ref, nil } - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) case NeverPrimaryEndpoint: - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) return ref, nil } } @@ -377,44 +807,30 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent) -} - -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { - // TODO(b/141022673): Validate IP address before adding them. - - // Sanity check. - id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} - if _, ok := n.endpoints[id]; ok { - // Endpoint already exists. - return nil, tcpip.ErrDuplicateAddress - } - - netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] + ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol } - // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP) - if err != nil { - return nil, err - } - isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) // If the address is an IPv6 address and it is a permanent address, - // mark it as tentative so it goes through the DAD process. + // mark it as tentative so it goes through the DAD process if the NIC is + // enabled. If the NIC is not enabled, DAD will be started when the NIC is + // enabled. if isIPv6Unicast && kind == permanent { kind = permanentTentative } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, + refs: 1, + addr: protocolAddress.AddressWithPrefix, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, + configType: configType, + deprecated: deprecated, } // Set up cache if link address resolution exists for this protocol. @@ -433,13 +849,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } - n.endpoints[id] = ref + n.mu.endpoints[id] = ref n.insertPrimaryEndpointLocked(ref, peb) - // If we are adding a tentative IPv6 address, start DAD. - if isIPv6Unicast && kind == permanentTentative { - if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { + // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. + if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { + if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } } @@ -452,7 +868,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addPermanentAddressLocked(protocolAddress, peb) + _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */) n.mu.Unlock() return err @@ -464,22 +880,18 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() - addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) - for nid, ref := range n.endpoints { + addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) + for _, ref := range n.mu.endpoints { // Don't include tentative, expired or temporary endpoints to // avoid confusion and prevent the caller from using those. switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - // TODO(b/140898488): Should tentative addresses be - // returned? + case permanentExpired, temporary: continue } + addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ref.protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: nid.LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, + Protocol: ref.protocol, + AddressWithPrefix: ref.addrWithPrefix(), }) } return addrs @@ -491,7 +903,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() var addrs []tcpip.ProtocolAddress - for proto, list := range n.primary { + for proto, list := range n.mu.primary { for _, ref := range list { // Don't include tentative, expired or tempory endpoints // to avoid confusion and prevent the caller from using @@ -502,59 +914,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { } addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: proto, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, + Protocol: proto, + AddressWithPrefix: ref.addrWithPrefix(), }) } } return addrs } -// AddAddressRange adds a range of addresses to n, so that it starts accepting -// packets targeted at the given addresses and network protocol. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { - n.mu.Lock() - n.addressRanges = append(n.addressRanges, subnet) - n.mu.Unlock() -} +// primaryAddress returns the primary address associated with this NIC. +// +// primaryAddress will return the first non-deprecated address if such an +// address exists. If no non-deprecated address exists, the first deprecated +// address will be returned. +func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { + n.mu.RLock() + defer n.mu.RUnlock() -// RemoveAddressRange removes the given address range from n. -func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { - n.mu.Lock() + list, ok := n.mu.primary[proto] + if !ok { + return tcpip.AddressWithPrefix{} + } - // Use the same underlying array. - tmp := n.addressRanges[:0] - for _, sub := range n.addressRanges { - if sub != subnet { - tmp = append(tmp, sub) + var deprecatedEndpoint *referencedNetworkEndpoint + for _, ref := range list { + // Don't include tentative, expired or tempory endpoints to avoid confusion + // and prevent the caller from using those. + switch ref.getKind() { + case permanentTentative, permanentExpired, temporary: + continue } - } - n.addressRanges = tmp - n.mu.Unlock() -} + if !ref.deprecated { + return ref.addrWithPrefix() + } -// Subnets returns the Subnets associated with this NIC. -func (n *NIC) AddressRanges() []tcpip.Subnet { - n.mu.RLock() - defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) - for nid := range n.endpoints { - sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) - if err != nil { - // This should never happen as the mask has been carefully crafted to - // match the address. - panic("Invalid endpoint subnet: " + err.Error()) + if deprecatedEndpoint == nil { + deprecatedEndpoint = ref } - sns = append(sns, sn) } - return append(sns, n.addressRanges...) + + if deprecatedEndpoint != nil { + return deprecatedEndpoint.addrWithPrefix() + } + + return tcpip.AddressWithPrefix{} } // insertPrimaryEndpointLocked adds r to n's primary endpoint list as required @@ -564,21 +968,21 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { switch peb { case CanBePrimaryEndpoint: - n.primary[r.protocol] = append(n.primary[r.protocol], r) + n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) case FirstPrimaryEndpoint: - n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) } } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { - id := *r.ep.ID() + id := NetworkEndpointID{LocalAddress: r.address()} // Nothing to do if the reference has already been replaced with a different // one. This happens in the case where 1) this endpoint's ref count hit zero // and was waiting (on the lock) to be removed and 2) the same address was // re-added in the meantime by removing this endpoint from the list and // adding a new one. - if n.endpoints[id] != r { + if n.mu.endpoints[id] != r { return } @@ -586,16 +990,15 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { panic("Reference count dropped to zero before being removed") } - delete(n.endpoints, id) - refs := n.primary[r.protocol] + delete(n.mu.endpoints, id) + refs := n.mu.primary[r.protocol] for i, ref := range refs { if ref == r { - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + refs[len(refs)-1] = nil break } } - - r.ep.Close() } func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { @@ -605,7 +1008,7 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.endpoints[NetworkEndpointID{addr}] + r, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadLocalAddress } @@ -615,26 +1018,45 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } - isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) - - // If we are removing a tentative IPv6 unicast address, stop DAD. - if isIPv6Unicast && kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + switch r.protocol { + case header.IPv6ProtocolNumber: + return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) + default: + r.expireLocked() + return nil } +} - r.setKind(permanentExpired) - if !r.decRefLocked() { - // The endpoint still has references to it. - return nil +func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { + addr := r.addrWithPrefix() + + isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) + + if isIPv6Unicast { + n.mu.ndp.stopDuplicateAddressDetection(addr.Address) + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + switch r.configType { + case slaac: + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case slaacTemp: + n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + } } + r.expireLocked() + // At this point the endpoint is deleted. // If we are removing an IPv6 unicast address, leave the solicited-node // multicast address. + // + // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node + // multicast group may be left by user action. if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr) - if err := n.leaveGroupLocked(snmc); err != nil { + snmc := header.SolicitedNodeAddr(addr.Address) + if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } } @@ -668,23 +1090,23 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A // outlined in RFC 3810 section 5. id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] + joins := n.mu.mcastJoins[id] if joins == 0 { netProto, ok := n.stack.networkProtocols[protocol] if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint); err != nil { + }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } - n.mcastJoins[id] = joins + 1 + n.mu.mcastJoins[id] = joins + 1 return nil } @@ -694,48 +1116,75 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.leaveGroupLocked(addr) + return n.leaveGroupLocked(addr, false /* force */) } // leaveGroupLocked decrements the count for the given multicast address, and // when it reaches zero removes the endpoint for this address. n MUST be locked // before leaveGroupLocked is called. -func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +// +// If force is true, then the count for the multicast addres is ignored and the +// endpoint will be removed immediately. +func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] - switch joins { - case 0: + joins, ok := n.mu.mcastJoins[id] + if !ok { // There are no joins with this address on this NIC. return tcpip.ErrBadLocalAddress - case 1: - // This is the last one, clean up. - if err := n.removePermanentAddressLocked(addr); err != nil { - return err - } } - n.mcastJoins[id] = joins - 1 + + joins-- + if force || joins == 0 { + // There are no outstanding joins or we are forced to leave, clean up. + delete(n.mu.mcastJoins, id) + return n.removePermanentAddressLocked(addr) + } + + n.mu.mcastJoins[id] = joins return nil } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { +// isInGroup returns true if n has joined the multicast group addr. +func (n *NIC) isInGroup(addr tcpip.Address) bool { + n.mu.RLock() + joins := n.mu.mcastJoins[NetworkEndpointID{addr}] + n.mu.RUnlock() + + return joins != 0 +} + +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - ref.ep.HandlePacket(&r, vv) + + ref.ep.HandlePacket(&r, pkt) ref.decRef() } // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when -// the NIC receives a packet from the physical interface. +// the NIC receives a packet from the link endpoint. // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + enabled := n.mu.enabled + // If the NIC is not yet enabled, don't receive any packets. + if !enabled { + n.mu.RUnlock() + + n.stats.DisabledRx.Packets.Increment() + n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + return + } + n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { + n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() return } @@ -747,32 +1196,59 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } // Are any packet sockets listening for this network protocol? - n.mu.RLock() - packetEPs := n.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) - } + packetEPs := n.mu.packetEPs[protocol] + // Add any other packet sockets that maybe listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, vv, linkHeader) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - if len(vv.First()) < netProto.MinimumPacketSize() { + // Parse headers. + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + // The packet is too small to contain a network header. n.stack.stats.MalformedRcvdPackets.Increment() return } + if hasTransportHdr { + // Parse the transport header if present. + if state, ok := n.stack.transportProtocols[transProtoNum]; ok { + state.proto.Parse(pkt) + } + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + + if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { + // The source address is one of our own, so we never should have gotten a + // packet like this unless handleLocal is false. Loopback also calls this + // function even though the packets didn't come from the physical interface + // so don't drop those. + n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() + return + } - src, dst := netProto.ParseAddresses(vv.First()) + // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. + // Loopback traffic skips the prerouting chain. + if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + // iptables filtering. + ipt := n.stack.IPTables() + address := n.primaryAddress(protocol) + if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { + // iptables is telling us to drop the packet. + return + } + } if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) + handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) return } @@ -783,54 +1259,98 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { - n.stack.stats.IP.InvalidAddressesReceived.Increment() + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() return } - defer r.Release() - - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote // Found a NIC. n := r.ref.nic n.mu.RLock() - ref, ok := n.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() + ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() n.mu.RUnlock() if ok { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() - } else { - // n doesn't have a destination endpoint. - // Send the packet out of n. - // If we want to send the packet to a link-layer, - // we have to reserve space for an Ethernet header. - hdr := buffer.NewPrependableFromView(vv.First(), int(n.linkEP.MaxHeaderLength())) - vv.RemoveFirst() - - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size())) + r.Release() + return + } + + // n doesn't have a destination endpoint. + // Send the packet out of n. + // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) + // forwarder will release route. + return } + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() + r.Release() + return } + + // The link-address resolution finished immediately. + n.forwardPacket(&r, protocol, pkt) + r.Release() return } // If a packet socket handled the packet, don't treat it as invalid. if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidAddressesReceived.Increment() + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() + } +} + +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that maybe listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) } } +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() + + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) +} + // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -842,41 +1362,60 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) + n.stack.demux.deliverRawPacket(r, protocol, pkt) + + // TransportHeader is empty only when pkt is an ICMP packet or was reassembled + // from fragments. + if pkt.TransportHeader().View().IsEmpty() { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. + if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + } else { + // This is either a bad packet or was re-assembled from fragments. + transProto.Parse(pkt) + } + } - if len(vv.First()) < transProto.MinimumPacketSize() { + if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, pkt, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, netHeader, vv) { + if state.defaultHandler(r, id, pkt) { return } } // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { n.stack.stats.MalformedRcvdPackets.Increment() } } // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -887,17 +1426,18 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(vv.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { return } } @@ -907,18 +1447,31 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Name returns the name of n. +func (n *NIC) Name() string { + return n.name +} + // Stack returns the instance of the Stack that owns this NIC. func (n *NIC) Stack() *Stack { return n.stack } +// LinkEndpoint returns the link endpoint of n. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // isAddrTentative returns true if addr is tentative on n. // // Note that if addr is not associated with n, then this function will return // false. It will only return true if the address is associated with the NIC // AND it is tentative. func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - ref, ok := n.endpoints[NetworkEndpointID{addr}] + n.mu.RLock() + defer n.mu.RUnlock() + + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false } @@ -926,15 +1479,17 @@ func (n *NIC) isAddrTentative(addr tcpip.Address) bool { return ref.getKind() == permanentTentative } -// dupTentativeAddrDetected attempts to inform n that a tentative addr -// is a duplicate on a link. +// dupTentativeAddrDetected attempts to inform n that a tentative addr is a +// duplicate on a link. // -// dupTentativeAddrDetected will delete the tentative address if it exists. +// dupTentativeAddrDetected will remove the tentative address if it exists. If +// the address was generated via SLAAC, an attempt will be made to generate a +// new address. func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadAddress } @@ -943,7 +1498,24 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - return n.removePermanentAddressLocked(addr) + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a + // new address will be generated for it. + if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { + return err + } + + prefix := ref.addrWithPrefix().Subnet() + + switch ref.configType { + case slaac: + n.mu.ndp.regenerateSLAACAddr(prefix) + case slaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + } + + return nil } // setNDPConfigs sets the NDP configurations for n. @@ -954,10 +1526,39 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { c.validate() n.mu.Lock() - n.ndp.configs = c + n.mu.ndp.configs = c n.mu.Unlock() } +// NUDConfigs gets the NUD configurations for n. +func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) { + if n.neigh == nil { + return NUDConfigurations{}, tcpip.ErrNotSupported + } + return n.neigh.config(), nil +} + +// setNUDConfigs sets the NUD configurations for n. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported + } + c.resetInvalidFields() + n.neigh.setConfig(c) + return nil +} + +// handleNDPRA handles an NDP Router Advertisement message that arrived on n. +func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + n.mu.Lock() + defer n.mu.Unlock() + + n.mu.ndp.handleRA(ip, ra) +} + type networkEndpointKind int32 const ( @@ -977,7 +1578,7 @@ const ( // removing the permanent address from the NIC. permanent - // An expired permanent endoint is a permanent endoint that had its address + // An expired permanent endpoint is a permanent endpoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes // hold a reference to it. This is achieved by decreasing its reference count // by 1. If its address is re-added before the endpoint is removed, its type @@ -997,11 +1598,11 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return tcpip.ErrNotSupported } - n.packetEPs[netProto] = append(eps, ep) + n.mu.packetEPs[netProto] = append(eps, ep) return nil } @@ -1010,21 +1611,40 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return } for i, epOther := range eps { if epOther == ep { - n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) return } } } +type networkEndpointConfigType int32 + +const ( + // A statically configured endpoint is an address that was added by + // some user-specified action (adding an explicit address, joining a + // multicast group). + static networkEndpointConfigType = iota + + // A SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4862 section 5.5.3. + slaac + + // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are + // not expected to be valid (or preferred) forever; hence the term temporary. + slaacTemp +) + type referencedNetworkEndpoint struct { ep NetworkEndpoint + addr tcpip.AddressWithPrefix nic *NIC protocol tcpip.NetworkProtocolNumber @@ -1038,6 +1658,24 @@ type referencedNetworkEndpoint struct { // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind + + // configType is the method that was used to configure this endpoint. + // This must never change except during endpoint creation and promotion to + // permanent. + configType networkEndpointConfigType + + // deprecated indicates whether or not the endpoint should be considered + // deprecated. That is, when deprecated is true, other endpoints that are not + // deprecated should be preferred. + deprecated bool +} + +func (r *referencedNetworkEndpoint) address() tcpip.Address { + return r.addr.Address +} + +func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { + return r.addr } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { @@ -1049,17 +1687,44 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address -// has been removed), or the NIC to be in spoofing mode. +// packet. It requires the endpoint to not be marked expired (i.e., its address) +// has been removed) unless the NIC is in spoofing mode, or temporary. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - return r.getKind() != permanentExpired || r.nic.spoofing + r.nic.mu.RLock() + defer r.nic.mu.RUnlock() + + return r.isValidForOutgoingRLocked() } -// isValidForIncoming returns true if the endpoint can accept an incoming -// packet. It requires the endpoint to not be marked expired (i.e., its address -// has been removed), or the NIC to be in promiscuous mode. -func (r *referencedNetworkEndpoint) isValidForIncoming() bool { - return r.getKind() != permanentExpired || r.nic.promiscuous +// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires +// r.nic.mu to be read locked. +func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { + if !r.nic.mu.enabled { + return false + } + + return r.isAssignedRLocked(r.nic.mu.spoofing) +} + +// isAssignedRLocked returns true if r is considered to be assigned to the NIC. +// +// r.nic.mu must be read locked. +func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { + switch r.getKind() { + case permanentTentative: + return false + case permanentExpired: + return spoofingOrPromiscuous + default: + return true + } +} + +// expireLocked decrements the reference count and marks the permanent endpoint +// as expired. +func (r *referencedNetworkEndpoint) expireLocked() { + r.setKind(permanentExpired) + r.decRefLocked() } // decRef decrements the ref count and cleans up the endpoint once it reaches @@ -1071,14 +1736,11 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. Returns true if the endpoint was removed. -func (r *referencedNetworkEndpoint) decRefLocked() bool { +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) - return true } - - return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go new file mode 100644 index 000000000..d312a79eb --- /dev/null +++ b/pkg/tcpip/stack/nic_test.go @@ -0,0 +1,316 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +var _ LinkEndpoint = (*testLinkEndpoint)(nil) + +// A LinkEndpoint that throws away outgoing packets. +// +// We use this instead of the channel endpoint as the channel package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testLinkEndpoint struct { + dispatcher NetworkDispatcher +} + +// Attach implements LinkEndpoint.Attach. +func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements LinkEndpoint.IsAttached. +func (e *testLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements LinkEndpoint.MTU. +func (*testLinkEndpoint) MTU() uint32 { + return math.MaxUint16 +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { + return CapabilityResolutionRequired +} + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*testLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Wait implements LinkEndpoint.Wait. +func (*testLinkEndpoint) Wait() {} + +// WritePacket implements LinkEndpoint.WritePacket. +func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) + +// An IPv6 NetworkEndpoint that throws away outgoing packets. +// +// We use this instead of ipv6.endpoint because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Endpoint struct { + nicID tcpip.NICID + linkEP LinkEndpoint + protocol *testIPv6Protocol +} + +// DefaultTTL implements NetworkEndpoint.DefaultTTL. +func (*testIPv6Endpoint) DefaultTTL() uint8 { + return 0 +} + +// MTU implements NetworkEndpoint.MTU. +func (e *testIPv6Endpoint) MTU() uint32 { + return e.linkEP.MTU() - header.IPv6MinimumSize +} + +// Capabilities implements NetworkEndpoint.Capabilities. +func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. +func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket implements NetworkEndpoint.WritePacket. +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements NetworkEndpoint.WritePackets. +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteHeaderIncludedPacket implements +// NetworkEndpoint.WriteHeaderIncludedPacket. +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// NICID implements NetworkEndpoint.NICID. +func (e *testIPv6Endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// HandlePacket implements NetworkEndpoint.HandlePacket. +func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { +} + +// Close implements NetworkEndpoint.Close. +func (*testIPv6Endpoint) Close() {} + +// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. +func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +var _ NetworkProtocol = (*testIPv6Protocol)(nil) + +// An IPv6 NetworkProtocol that supports the bare minimum to make a stack +// believe it supports IPv6. +// +// We use this instead of ipv6.protocol because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Protocol struct{} + +// Number implements NetworkProtocol.Number. +func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. +func (*testIPv6Protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. +func (*testIPv6Protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint implements NetworkProtocol.NewEndpoint. +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { + return &testIPv6Endpoint{ + nicID: nicID, + linkEP: linkEP, + protocol: p, + } +} + +// SetOption implements NetworkProtocol.SetOption. +func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { + return nil +} + +// Option implements NetworkProtocol.Option. +func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. +func (*testIPv6Protocol) Close() {} + +// Wait implements NetworkProtocol.Wait. +func (*testIPv6Protocol) Wait() {} + +// Parse implements NetworkProtocol.Parse. +func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + return 0, false, false +} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) + +// LinkAddressProtocol implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { + return nil +} + +// ResolveStaticAddress implements LinkAddressResolver. +func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if header.IsV6MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv6Address(addr), true + } + return "", false +} + +// Test the race condition where a NIC is removed and an RS timer fires at the +// same time. +func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { + const ( + nicID = 1 + + maxRtrSolicitations = 5 + ) + + e := testLinkEndpoint{} + s := New(Options{ + NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NDPConfigs: NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: minimumRtrSolicitationInterval, + }, + }) + + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) + } + + s.mu.Lock() + // Wait for the router solicitation timer to fire and block trying to obtain + // the stack lock when doing link address resolution. + time.Sleep(minimumRtrSolicitationInterval * 2) + if err := s.removeNICLocked(nicID); err != nil { + t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) + } + s.mu.Unlock() +} + +func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { + // When the NIC is disabled, the only field that matters is the stats field. + // This test is limited to stats counter checks. + nic := NIC{ + stats: makeNICStats(), + } + + if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + t.Errorf("got DisabledRx.Packets = %d, want = 0", got) + } + if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) + } + if got := nic.stats.Rx.Packets.Value(); got != 0 { + t.Errorf("got Rx.Packets = %d, want = 0", got) + } + if got := nic.stats.Rx.Bytes.Value(); got != 0 { + t.Errorf("got Rx.Bytes = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ + Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), + })) + + if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + t.Errorf("got DisabledRx.Packets = %d, want = 1", got) + } + if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) + } + if got := nic.stats.Rx.Packets.Value(); got != 0 { + t.Errorf("got Rx.Packets = %d, want = 0", got) + } + if got := nic.stats.Rx.Bytes.Value(); got != 0 { + t.Errorf("got Rx.Bytes = %d, want = 0", got) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go new file mode 100644 index 000000000..e1ec15487 --- /dev/null +++ b/pkg/tcpip/stack/nud.go @@ -0,0 +1,466 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "math" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // defaultBaseReachableTime is the default base duration for computing the + // random reachable time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of a uniformly distributed random value between the minimum and + // maximum random factors, multiplied by the base reachable time. Using a + // random component eliminates the possibility that Neighbor Unreachability + // Detection messages will synchronize with each other. + // + // Default taken from REACHABLE_TIME of RFC 4861 section 10. + defaultBaseReachableTime = 30 * time.Second + + // minimumBaseReachableTime is the minimum base duration for computing the + // random reachable time. + // + // Minimum = 1ms + minimumBaseReachableTime = time.Millisecond + + // defaultMinRandomFactor is the default minimum value of the random factor + // used for computing reachable time. + // + // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10. + defaultMinRandomFactor = 0.5 + + // defaultMaxRandomFactor is the default maximum value of the random factor + // used for computing reachable time. + // + // The default value depends on the value of MinRandomFactor. + // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10, + // the value from the RFC will be used; otherwise, the default is + // MinRandomFactor multiplied by three. + defaultMaxRandomFactor = 1.5 + + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + + // defaultDelayFirstProbeTime is the default duration to wait for a + // non-Neighbor-Discovery related protocol to reconfirm reachability after + // entering the DELAY state. After this time, a reachability probe will be + // sent and the entry will transition to the PROBE state. + // + // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10. + defaultDelayFirstProbeTime = 5 * time.Second + + // defaultMaxMulticastProbes is the default number of reachabililty probes + // to send before concluding negative reachability and deleting the neighbor + // entry from the INCOMPLETE state. + // + // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10. + defaultMaxMulticastProbes = 3 + + // defaultMaxUnicastProbes is the default number of reachability probes to + // send before concluding retransmission from within the PROBE state should + // cease and the entry SHOULD be deleted. + // + // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10. + defaultMaxUnicastProbes = 3 + + // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD + // delay sending a response for a random time between 0 and this time, if the + // target address is an anycast address. + // + // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10. + defaultMaxAnycastDelayTime = time.Second + + // defaultMaxReachbilityConfirmations is the default amount of unsolicited + // reachability confirmation messages a node MAY send to all-node multicast + // address when it determines its link-layer address has changed. + // + // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. + defaultMaxReachbilityConfirmations = 3 + + // defaultUnreachableTime is the default duration for how long an entry will + // remain in the FAILED state before being removed from the neighbor cache. + // + // Note, there is no equivalent protocol constant defined in RFC 4861. It + // leaves the specifics of any garbage collection mechanism up to the + // implementation. + defaultUnreachableTime = 5 * time.Second +) + +// NUDDispatcher is the interface integrators of netstack must implement to +// receive and handle NUD related events. +type NUDDispatcher interface { + // OnNeighborAdded will be called when a new entry is added to a NIC's (with + // ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) + // neighbor table changes state and/or link address. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborRemoved will be called when an entry is removed from a NIC's + // (with ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) +} + +// ReachabilityConfirmationFlags describes the flags used within a reachability +// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, +// respectively). +type ReachabilityConfirmationFlags struct { + // Solicited indicates that the advertisement was sent in response to a + // reachability probe. + Solicited bool + + // Override indicates that the reachability confirmation should override an + // existing neighbor cache entry and update the cached link-layer address. + // When Override is not set the confirmation will not update a cached + // link-layer address, but will update an existing neighbor cache entry for + // which no link-layer address is known. + Override bool + + // IsRouter indicates that the sender is a router. + IsRouter bool +} + +// NUDHandler communicates external events to the Neighbor Unreachability +// Detection state machine, which is implemented per-interface. This is used by +// network endpoints to inform the Neighbor Cache of probes and confirmations. +type NUDHandler interface { + // HandleProbe processes an incoming neighbor probe (e.g. ARP request or + // Neighbor Solicitation for ARP or NDP, respectively). Validation of the + // probe needs to be performed before calling this function since the + // Neighbor Cache doesn't have access to view the NIC's assigned addresses. + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + + // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP + // reply or Neighbor Advertisement for ARP or NDP, respectively). + HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) + + // HandleUpperLevelConfirmation processes an incoming upper-level protocol + // (e.g. TCP acknowledgements) reachability confirmation. + HandleUpperLevelConfirmation(addr tcpip.Address) +} + +// NUDConfigurations is the NUD configurations for the netstack. This is used +// by the neighbor cache to operate the NUD state machine on each device in the +// local network. +type NUDConfigurations struct { + // BaseReachableTime is the base duration for computing the random reachable + // time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of uniformly distributed random value between minRandomFactor and + // maxRandomFactor multiplied by baseReachableTime. Using a random component + // eliminates the possibility that Neighbor Unreachability Detection messages + // will synchronize with each other. + // + // After this time, a neighbor entry will transition from REACHABLE to STALE + // state. + // + // Must be greater than 0. + BaseReachableTime time.Duration + + // LearnBaseReachableTime enables learning BaseReachableTime during runtime + // from the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option. + LearnBaseReachableTime bool + + // MinRandomFactor is the minimum value of the random factor used for + // computing reachable time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be greater than 0. + MinRandomFactor float32 + + // MaxRandomFactor is the maximum value of the random factor used for + // computing reachabile time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be great than or equal to MinRandomFactor. + MaxRandomFactor float32 + + // RetransmitTimer is the duration between retransmission of reachability + // probes in the PROBE state. + RetransmitTimer time.Duration + + // LearnRetransmitTimer enables learning RetransmitTimer during runtime from + // the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option. + LearnRetransmitTimer bool + + // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery + // related protocol to reconfirm reachability after entering the DELAY state. + // After this time, a reachability probe will be sent and the entry will + // transition to the PROBE state. + // + // Must be greater than 0. + DelayFirstProbeTime time.Duration + + // MaxMulticastProbes is the number of reachability probes to send before + // concluding negative reachability and deleting the neighbor entry from the + // INCOMPLETE state. + // + // Must be greater than 0. + MaxMulticastProbes uint32 + + // MaxUnicastProbes is the number of reachability probes to send before + // concluding retransmission from within the PROBE state should cease and + // entry SHOULD be deleted. + // + // Must be greater than 0. + MaxUnicastProbes uint32 + + // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a + // response for a random time between 0 and this time, if the target address + // is an anycast address. + // + // TODO(gvisor.dev/issue/2242): Use this option when sending solicited + // neighbor confirmations to anycast addresses and proxying neighbor + // confirmations. + MaxAnycastDelayTime time.Duration + + // MaxReachabilityConfirmations is the number of unsolicited reachability + // confirmation messages a node MAY send to all-node multicast address when + // it determines its link-layer address has changed. + // + // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD + // configuration option is necessary. + MaxReachabilityConfirmations uint32 + + // UnreachableTime describes how long an entry will remain in the FAILED + // state before being removed from the neighbor cache. + UnreachableTime time.Duration +} + +// DefaultNUDConfigurations returns a NUDConfigurations populated with default +// values defined by RFC 4861 section 10. +func DefaultNUDConfigurations() NUDConfigurations { + return NUDConfigurations{ + BaseReachableTime: defaultBaseReachableTime, + LearnBaseReachableTime: true, + MinRandomFactor: defaultMinRandomFactor, + MaxRandomFactor: defaultMaxRandomFactor, + RetransmitTimer: defaultRetransmitTimer, + LearnRetransmitTimer: true, + DelayFirstProbeTime: defaultDelayFirstProbeTime, + MaxMulticastProbes: defaultMaxMulticastProbes, + MaxUnicastProbes: defaultMaxUnicastProbes, + MaxAnycastDelayTime: defaultMaxAnycastDelayTime, + MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, + UnreachableTime: defaultUnreachableTime, + } +} + +// resetInvalidFields modifies an invalid NDPConfigurations with valid values. +// If invalid values are present in c, the corresponding default values will be +// used instead. This is needed to check, and conditionally fix, user-specified +// NUDConfigurations. +func (c *NUDConfigurations) resetInvalidFields() { + if c.BaseReachableTime < minimumBaseReachableTime { + c.BaseReachableTime = defaultBaseReachableTime + } + if c.MinRandomFactor <= 0 { + c.MinRandomFactor = defaultMinRandomFactor + } + if c.MaxRandomFactor < c.MinRandomFactor { + c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor) + } + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } + if c.DelayFirstProbeTime == 0 { + c.DelayFirstProbeTime = defaultDelayFirstProbeTime + } + if c.MaxMulticastProbes == 0 { + c.MaxMulticastProbes = defaultMaxMulticastProbes + } + if c.MaxUnicastProbes == 0 { + c.MaxUnicastProbes = defaultMaxUnicastProbes + } + if c.UnreachableTime == 0 { + c.UnreachableTime = defaultUnreachableTime + } +} + +// calcMaxRandomFactor calculates the maximum value of the random factor used +// for computing reachable time. This function is necessary for when the +// default specified in RFC 4861 section 10 is less than the current +// MinRandomFactor. +// +// Assumes minRandomFactor is positive since validation of the minimum value +// should come before the validation of the maximum. +func calcMaxRandomFactor(minRandomFactor float32) float32 { + if minRandomFactor > defaultMaxRandomFactor { + return minRandomFactor * 3 + } + return defaultMaxRandomFactor +} + +// A Rand is a source of random numbers. +type Rand interface { + // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). + Float32() float32 +} + +// NUDState stores states needed for calculating reachable time. +type NUDState struct { + rng Rand + + // mu protects the fields below. + // + // It is necessary for NUDState to handle its own locking since neighbor + // entries may access the NUD state from within the goroutine spawned by + // time.AfterFunc(). This goroutine may run concurrently with the main + // process for controlling the neighbor cache and would otherwise introduce + // race conditions if NUDState was not locked properly. + mu sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 +} + +// NewNUDState returns new NUDState using c as configuration and the specified +// random number generator for use in recomputing ReachableTime. +func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { + s := &NUDState{ + rng: rng, + } + s.config = c + return s +} + +// Config returns the NUD configuration. +func (s *NUDState) Config() NUDConfigurations { + s.mu.RLock() + defer s.mu.RUnlock() + return s.config +} + +// SetConfig replaces the existing NUD configurations with c. +func (s *NUDState) SetConfig(c NUDConfigurations) { + s.mu.Lock() + defer s.mu.Unlock() + s.config = c +} + +// ReachableTime returns the duration to wait for a REACHABLE entry to +// transition into STALE after inactivity. This value is recalculated for new +// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the +// algorithm defined in RFC 4861 section 6.3.2. +func (s *NUDState) ReachableTime() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + + if time.Now().After(s.expiration) || + s.config.BaseReachableTime != s.prevBaseReachableTime || + s.config.MinRandomFactor != s.prevMinRandomFactor || + s.config.MaxRandomFactor != s.prevMaxRandomFactor { + return s.recomputeReachableTimeLocked() + } + return s.reachableTime +} + +// recomputeReachableTimeLocked forces a recalculation of ReachableTime using +// the algorithm defined in RFC 4861 section 6.3.2. +// +// This SHOULD automatically be invoked during certain situations, as per +// RFC 4861 section 6.3.4: +// +// If the received Reachable Time value is non-zero, the host SHOULD set its +// BaseReachableTime variable to the received value. If the new value +// differs from the previous value, the host SHOULD re-compute a new random +// ReachableTime value. ReachableTime is computed as a uniformly +// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR +// times the BaseReachableTime. Using a random component eliminates the +// possibility that Neighbor Unreachability Detection messages will +// synchronize with each other. +// +// In most cases, the advertised Reachable Time value will be the same in +// consecutive Router Advertisements, and a host's BaseReachableTime rarely +// changes. In such cases, an implementation SHOULD ensure that a new +// random value gets re-computed at least once every few hours. +// +// s.mu MUST be locked for writing. +func (s *NUDState) recomputeReachableTimeLocked() time.Duration { + s.prevBaseReachableTime = s.config.BaseReachableTime + s.prevMinRandomFactor = s.config.MinRandomFactor + s.prevMaxRandomFactor = s.config.MaxRandomFactor + + randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + + // Check for overflow, given that minRandomFactor and maxRandomFactor are + // guaranteed to be positive numbers. + if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { + s.reachableTime = time.Duration(math.MaxInt64) + } else if randomFactor == 1 { + // Avoid loss of precision when a large base reachable time is used. + s.reachableTime = s.config.BaseReachableTime + } else { + reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) + s.reachableTime = time.Duration(reachableTime) + } + + s.expiration = time.Now().Add(2 * time.Hour) + return s.reachableTime +} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go new file mode 100644 index 000000000..2494ee610 --- /dev/null +++ b/pkg/tcpip/stack/nud_test.go @@ -0,0 +1,795 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 + defaultMaxAnycastDelayTime = time.Second + defaultMaxReachbilityConfirmations = 3 + defaultUnreachableTime = 5 * time.Second + + defaultFakeRandomNum = 0.5 +) + +// fakeRand is a deterministic random number generator. +type fakeRand struct { + num float32 +} + +var _ stack.Rand = (*fakeRand)(nil) + +func (f *fakeRand) Float32() float32 { + return f.num +} + +// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NUD configurations using an invalid NICID. +func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to retrieve NUD configurations when the +// stack doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to set NUD configurations when the stack +// doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + } +} + +// TestDefaultNUDConfigurationIsValid verifies that calling +// resetInvalidFields() on the result of DefaultNUDConfigurations() does not +// change anything. DefaultNUDConfigurations() should return a valid +// NUDConfigurations. +func TestDefaultNUDConfigurations(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + c, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got, want := c, stack.DefaultNUDConfigurations(); got != want { + t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + } +} + +func TestNUDConfigurationsBaseReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + baseReachableTime: 0, + want: defaultBaseReachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + baseReachableTime: time.Millisecond, + want: time.Millisecond, + }, + { + name: "MoreThanDefaultBaseReachableTime", + baseReachableTime: 2 * defaultBaseReachableTime, + want: 2 * defaultBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = test.baseReachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.BaseReachableTime; got != test.want { + t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMinRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: -1, + want: defaultMinRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: 0, + want: defaultMinRandomFactor, + }, + // Valid cases + { + name: "MoreThanZero", + minRandomFactor: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MinRandomFactor; got != test.want { + t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + maxRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: -1, + want: defaultMaxRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: 0, + want: defaultMaxRandomFactor, + }, + { + name: "LessThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 0.99, + want: defaultMaxRandomFactor, + }, + { + name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", + minRandomFactor: defaultMaxRandomFactor * 2, + maxRandomFactor: defaultMaxRandomFactor, + want: defaultMaxRandomFactor * 6, + }, + // Valid cases + { + name: "EqualToMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor, + want: defaultMinRandomFactor, + }, + { + name: "MoreThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 1.1, + want: defaultMinRandomFactor * 1.1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxRandomFactor; got != test.want { + t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsRetransmitTimer(t *testing.T) { + tests := []struct { + name string + retransmitTimer time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + retransmitTimer: 0, + want: defaultRetransmitTimer, + }, + { + name: "LessThanMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer - time.Nanosecond, + want: defaultRetransmitTimer, + }, + // Valid cases + { + name: "EqualToMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer, + want: minimumBaseReachableTime, + }, + { + name: "LargetThanMinimumRetransmitTimer", + retransmitTimer: 2 * minimumBaseReachableTime, + want: 2 * minimumBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.RetransmitTimer = test.retransmitTimer + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.RetransmitTimer; got != test.want { + t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { + tests := []struct { + name string + delayFirstProbeTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + delayFirstProbeTime: 0, + want: defaultDelayFirstProbeTime, + }, + // Valid cases + { + name: "MoreThanZero", + delayFirstProbeTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.DelayFirstProbeTime = test.delayFirstProbeTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.DelayFirstProbeTime; got != test.want { + t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { + tests := []struct { + name string + maxMulticastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxMulticastProbes: 0, + want: defaultMaxMulticastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxMulticastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxMulticastProbes = test.maxMulticastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxMulticastProbes; got != test.want { + t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { + tests := []struct { + name string + maxUnicastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxUnicastProbes: 0, + want: defaultMaxUnicastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxUnicastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxUnicastProbes = test.maxUnicastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxUnicastProbes; got != test.want { + t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsUnreachableTime(t *testing.T) { + tests := []struct { + name string + unreachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + unreachableTime: 0, + want: defaultUnreachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + unreachableTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.UnreachableTime = test.unreachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.UnreachableTime; got != test.want { + t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +// TestNUDStateReachableTime verifies the correctness of the ReachableTime +// computation. +func TestNUDStateReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "AllZeros", + baseReachableTime: 0, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMaxRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMinRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 1, + want: time.Duration(defaultFakeRandomNum * float32(time.Second)), + }, + { + name: "FractionalRandomFactor", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 0.001, + maxRandomFactor: 0.002, + want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), + }, + { + name: "MinAndMaxRandomFactorsEqual", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Second, + }, + { + name: "MinAndMaxRandomFactorsDifferent", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 2, + want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), + }, + { + name: "MaxInt64", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Duration(math.MaxInt64), + }, + { + name: "Overflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1.5, + maxRandomFactor: 1.5, + want: time.Duration(math.MaxInt64), + }, + { + name: "DoubleOverflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 2.5, + maxRandomFactor: 2.5, + want: time.Duration(math.MaxInt64), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.NUDConfigurations{ + BaseReachableTime: test.baseReachableTime, + MinRandomFactor: test.minRandomFactor, + MaxRandomFactor: test.maxRandomFactor, + } + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} + +// TestNUDStateRecomputeReachableTime exercises the ReachableTime function +// twice to verify recomputation of reachable time when the min random factor, +// max random factor, or base reachable time changes. +func TestNUDStateRecomputeReachableTime(t *testing.T) { + const defaultBase = time.Second + const defaultMin = 2.0 * defaultMaxRandomFactor + const defaultMax = 3.0 * defaultMaxRandomFactor + + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "BaseReachableTime", + baseReachableTime: 2 * defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMax, + want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + { + name: "MinRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMax, + maxRandomFactor: defaultMax, + want: time.Duration(defaultMax * float32(defaultBase)), + }, + { + name: "MaxRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMin, + want: time.Duration(defaultMin * float32(defaultBase)), + }, + { + name: "BothRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), + }, + { + name: "BaseReachableTimeAndBothRandomFactors", + baseReachableTime: 2 * defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = defaultBase + c.MinRandomFactor = defaultMin + c.MaxRandomFactor = defaultMax + + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + old := s.ReachableTime() + + if got, want := s.ReachableTime(), old; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Check for recomputation when changing the min random factor, the max + // random factor, the base reachability time, or any permutation of those + // three options. + c.BaseReachableTime = test.baseReachableTime + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + s.SetConfig(c) + + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Verify that ReachableTime isn't recomputed when none of the + // configuration options change. The random factor is changed so that if + // a recompution were to occur, ReachableTime would change. + rng.num = defaultFakeRandomNum / 2.0 + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go new file mode 100644 index 000000000..17b8beebb --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer.go @@ -0,0 +1,299 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +type headerType int + +const ( + linkHeader headerType = iota + networkHeader + transportHeader + numHeaderType +) + +// PacketBufferOptions specifies options for PacketBuffer creation. +type PacketBufferOptions struct { + // ReserveHeaderBytes is the number of bytes to reserve for headers. Total + // number of bytes pushed onto the headers must not exceed this value. + ReserveHeaderBytes int + + // Data is the initial unparsed data for the new packet. If set, it will be + // owned by the new packet. + Data buffer.VectorisedView +} + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. +// +// The whole packet is expected to be a series of bytes in the following order: +// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be +// empty. Use of PacketBuffer in any other order is unsupported. +// +// PacketBuffer must be created with NewPacketBuffer. +type PacketBuffer struct { + _ sync.NoCopy + + // PacketBufferEntry is used to build an intrusive list of + // PacketBuffers. + PacketBufferEntry + + // Data holds the payload of the packet. + // + // For inbound packets, Data is initially the whole packet. Then gets moved to + // headers via PacketHeader.Consume, when the packet is being parsed. + // + // For outbound packets, Data is the innermost layer, defined by the protocol. + // Headers are pushed in front of it via PacketHeader.Push. + // + // The bytes backing Data are immutable, a.k.a. users shouldn't write to its + // backing storage. + Data buffer.VectorisedView + + // headers stores metadata about each header. + headers [numHeaderType]headerInfo + + // header is the internal storage for outbound packets. Headers will be pushed + // (prepended) on this storage as the packet is being constructed. + // + // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and + // data are held in the same underlying buffer storage. + header buffer.Prependable + + // NetworkProtocol is only valid when NetworkHeader is set. + // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol + // numbers in registration APIs that take a PacketBuffer. + NetworkProtocolNumber tcpip.NetworkProtocolNumber + + // Hash is the transport layer hash of this packet. A value of zero + // indicates no valid hash has been set. + Hash uint32 + + // Owner is implemented by task to get the uid and gid. + // Only set for locally generated packets. + Owner tcpip.PacketOwner + + // The following fields are only set by the qdisc layer when the packet + // is added to a queue. + EgressRoute *Route + GSOOptions *GSO + + // NatDone indicates if the packet has been manipulated as per NAT + // iptables rule. + NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType +} + +// NewPacketBuffer creates a new PacketBuffer with opts. +func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { + pk := &PacketBuffer{ + Data: opts.Data, + } + if opts.ReserveHeaderBytes != 0 { + pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + } + return pk +} + +// ReservedHeaderBytes returns the number of bytes initially reserved for +// headers. +func (pk *PacketBuffer) ReservedHeaderBytes() int { + return pk.header.UsedLength() + pk.header.AvailableLength() +} + +// AvailableHeaderBytes returns the number of bytes currently available for +// headers. This is relevant to PacketHeader.Push method only. +func (pk *PacketBuffer) AvailableHeaderBytes() int { + return pk.header.AvailableLength() +} + +// LinkHeader returns the handle to link-layer header. +func (pk *PacketBuffer) LinkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: linkHeader, + } +} + +// NetworkHeader returns the handle to network-layer header. +func (pk *PacketBuffer) NetworkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: networkHeader, + } +} + +// TransportHeader returns the handle to transport-layer header. +func (pk *PacketBuffer) TransportHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: transportHeader, + } +} + +// HeaderSize returns the total size of all headers in bytes. +func (pk *PacketBuffer) HeaderSize() int { + // Note for inbound packets (Consume called), headers are not stored in + // pk.header. Thus, calculation of size of each header is needed. + var size int + for i := range pk.headers { + size += len(pk.headers[i].buf) + } + return size +} + +// Size returns the size of packet in bytes. +func (pk *PacketBuffer) Size() int { + return pk.HeaderSize() + pk.Data.Size() +} + +// Views returns the underlying storage of the whole packet. +func (pk *PacketBuffer) Views() []buffer.View { + // Optimization for outbound packets that headers are in pk.header. + useHeader := true + for i := range pk.headers { + if !canUseHeader(&pk.headers[i]) { + useHeader = false + break + } + } + + dataViews := pk.Data.Views() + + var vs []buffer.View + if useHeader { + vs = make([]buffer.View, 0, 1+len(dataViews)) + vs = append(vs, pk.header.View()) + } else { + vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) + for i := range pk.headers { + if v := pk.headers[i].buf; len(v) > 0 { + vs = append(vs, v) + } + } + } + return append(vs, dataViews...) +} + +func canUseHeader(h *headerInfo) bool { + // h.offset will be negative if the header was pushed in to prependable + // portion, or doesn't matter when it's empty. + return len(h.buf) == 0 || h.offset < 0 +} + +func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + } + h.buf = buffer.View(pk.header.Prepend(size)) + h.offset = -pk.header.UsedLength() + return h.buf +} + +func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) + } + v, ok := pk.Data.PullUp(size) + if !ok { + return + } + pk.Data.TrimFront(size) + h.buf = v + return h.buf, true +} + +// Clone makes a shallow copy of pk. +// +// Clone should be called in such cases so that no modifications is done to +// underlying packet payload. +func (pk *PacketBuffer) Clone() *PacketBuffer { + newPk := &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + } + return newPk +} + +// headerInfo stores metadata about a header in a packet. +type headerInfo struct { + // buf is the memorized slice for both prepended and consumed header. + // When header is prepended, buf serves as memorized value, which is a slice + // of pk.header. When header is consumed, buf is the slice pulled out from + // pk.Data, which is the only place to hold this header. + buf buffer.View + + // offset will be a negative number denoting the offset where this header is + // from the end of pk.header, if it is prepended. Otherwise, zero. + offset int +} + +// PacketHeader is a handle object to a header in the underlying packet. +type PacketHeader struct { + pk *PacketBuffer + typ headerType +} + +// View returns the underlying storage of h. +func (h PacketHeader) View() buffer.View { + return h.pk.headers[h.typ].buf +} + +// Push pushes size bytes in the front of its residing packet, and returns the +// backing storage. Callers may only call one of Push or Consume once on each +// header in the lifetime of the underlying packet. +func (h PacketHeader) Push(size int) buffer.View { + return h.pk.push(h.typ, size) +} + +// Consume moves the first size bytes of the unparsed data portion in the packet +// to h, and returns the backing storage. In the case of data is shorter than +// size, consumed will be false, and the state of h will not be affected. +// Callers may only call one of Push or Consume once on each header in the +// lifetime of the underlying packet. +func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { + return h.pk.consume(h.typ, size) +} + +// PayloadSince returns packet payload starting from and including a particular +// header. This method isn't optimized and should be used in test only. +func PayloadSince(h PacketHeader) buffer.View { + var v buffer.View + for _, hinfo := range h.pk.headers[h.typ:] { + v = append(v, hinfo.buf...) + } + return append(v, h.pk.Data.ToView()...) +} diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go new file mode 100644 index 000000000..c6fa8da5f --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestPacketHeaderPush(t *testing.T) { + for _, test := range []struct { + name string + reserved int + link []byte + network []byte + transport []byte + data []byte + }{ + { + name: "construct empty packet", + }, + { + name: "construct link header only packet", + reserved: 60, + link: makeView(10), + }, + { + name: "construct link and network header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + }, + { + name: "construct header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + }, + { + name: "construct data only packet", + data: makeView(40), + }, + { + name: "construct L3 packet", + reserved: 60, + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + { + name: "construct L2 packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + allHdrSize := len(test.link) + len(test.network) + len(test.transport) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Push headers. + if v := test.transport; len(v) > 0 { + copy(pk.TransportHeader().Push(len(v)), v) + } + if v := test.network; len(v) > 0 { + copy(pk.NetworkHeader().Push(len(v)), v) + } + if v := test.link; len(v) > 0 { + copy(pk.LinkHeader().Push(len(v)), v) + } + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), allHdrSize+len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), + concatViews(test.link, test.network, test.transport, test.data)) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(test.link, test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(test.transport, test.data)) + }) + } +} + +func TestPacketHeaderConsume(t *testing.T) { + for _, test := range []struct { + name string + data []byte + link int + network int + transport int + }{ + { + name: "parse L2 packet", + data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), + link: 10, + network: 20, + transport: 30, + }, + { + name: "parse L3 packet", + data: concatViews(makeView(20), makeView(30), makeView(40)), + network: 20, + transport: 30, + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Consume headers. + if size := test.link; size > 0 { + if _, ok := pk.LinkHeader().Consume(size); !ok { + t.Fatalf("pk.LinkHeader().Consume() = false, want true") + } + } + if size := test.network; size > 0 { + if _, ok := pk.NetworkHeader().Consume(size); !ok { + t.Fatalf("pk.NetworkHeader().Consume() = false, want true") + } + } + if size := test.transport; size > 0 { + if _, ok := pk.TransportHeader().Consume(size); !ok { + t.Fatalf("pk.TransportHeader().Consume() = false, want true") + } + } + + allHdrSize := test.link + test.network + test.transport + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), 0; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), 0; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + // After state of pk. + var ( + link = test.data[:test.link] + network = test.data[test.link:][:test.network] + transport = test.data[test.link+test.network:][:test.transport] + payload = test.data[allHdrSize:] + ) + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(link, network, transport, payload)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(network, transport, payload)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(transport, payload)) + }) + } +} + +func TestPacketHeaderConsumeDataTooShort(t *testing.T) { + data := makeView(10) + + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + // Consume should fail if pkt.Data is too short. + if _, ok := pk.LinkHeader().Consume(11); ok { + t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.NetworkHeader().Consume(11); ok { + t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.TransportHeader().Consume(11); ok { + t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") + } + + // Check packet should look the same as initial packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(data).ToVectorisedView(), + }) +} + +func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Second push should have panicked") + }) + } +} + +func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { + if _, ok := h.Consume(headerSize); !ok { + t.Fatal("First consume should succeed") + } + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Second consume should have panicked") + }) + } +} + +func TestPacketHeaderPushThenConsumePanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Consume should have panicked") + }) + } +} + +func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Consume(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Push should have panicked") + }) + } +} + +func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { + t.Helper() + reserved := opts.ReserveHeaderBytes + if got, want := pk.ReservedHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), 0; got != want { + t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) + } + data := opts.Data.ToView() + if got, want := pk.Size(), len(data); got != want { + t.Errorf("Initial pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) + // Check the initial values for each header. + checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) + checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) + checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) + // Check the initial valies for PayloadSince. + checkViewEqual(t, "Initial PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), data) +} + +func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { + t.Helper() + checkViewEqual(t, name+".View()", h.View(), want) +} + +func checkViewEqual(t *testing.T, what string, got, want buffer.View) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("%s = %x, want %x", what, got, want) + } +} + +func makeView(size int) buffer.View { + b := byte(size) + return bytes.Repeat([]byte{b}, size) +} + +func concatViews(views ...buffer.View) buffer.View { + var all buffer.View + for _, v := range views { + all = append(all, v...) + } + return all +} diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go new file mode 100644 index 000000000..421fb5c15 --- /dev/null +++ b/pkg/tcpip/stack/rand.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + mathrand "math/rand" + + "gvisor.dev/gvisor/pkg/sync" +) + +// lockedRandomSource provides a threadsafe rand.Source. +type lockedRandomSource struct { + mu sync.Mutex + src mathrand.Source +} + +func (r *lockedRandomSource) Int63() (n int64) { + r.mu.Lock() + n = r.src.Int63() + r.mu.Unlock() + return n +} + +func (r *lockedRandomSource) Seed(seed int64) { + r.mu.Lock() + r.src.Seed(seed) + r.mu.Unlock() +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0869fb084..aca2f77f8 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -51,8 +52,11 @@ type TransportEndpointID struct { type ControlType int // The following are the allowed values for ControlType values. +// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlPacketTooBig ControlType = iota + ControlNetworkUnreachable ControlType = iota + ControlNoRoute + ControlPacketTooBig ControlPortUnreachable ControlUnknown ) @@ -60,13 +64,34 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns an unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + // this transport endpoint. It sets pkt.TransportHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g., + // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + // HandleControlPacket takes ownership of pkt. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) + + // Abort initiates an expedited endpoint teardown. It puts the endpoint + // in a closed state and frees all resources associated with it. This + // cleanup may happen asynchronously. Wait can be used to block on this + // asynchronous cleanup. + Abort() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw @@ -77,7 +102,9 @@ type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. - HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt *PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -93,7 +120,9 @@ type PacketEndpoint interface { // // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. - HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View) + // + // HandlePacket takes ownership of pkt. + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -123,7 +152,9 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + // + // HandleUnknownDestinationPacket takes ownership of pkt. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -134,6 +165,18 @@ type TransportProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() + + // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does + // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() < + // MinimumPacketSize() + Parse(pkt *PacketBuffer) (ok bool) } // TransportDispatcher contains the methods used by the network stack to deliver @@ -141,13 +184,21 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. It also returns the network layer - // header for the enpoint to inspect or pass up the stack. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) + // transport protocol endpoint. + // + // pkt.NetworkHeader must be set before calling DeliverTransportPacket. + // + // DeliverTransportPacket takes ownership of pkt. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) + // + // pkt.NetworkHeader must be set before calling + // DeliverTransportControlPacket. + // + // DeliverTransportControlPacket takes ownership of pkt. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -198,32 +249,34 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error + // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // already been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. - WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) + // protocol. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network - // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error - - // ID returns the network protocol endpoint ID. - ID() *NetworkEndpointID - - // PrefixLen returns the network endpoint's subnet prefix length in bits. - PrefixLen() int + // header to the given destination address. It takes ownership of pkt. + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error // NICID returns the id of the NIC this endpoint belongs to. NICID() tcpip.NICID // HandlePacket is called by the link layer when new packets arrive to - // this network endpoint. - HandlePacket(r *Route, vv buffer.VectorisedView) + // this network endpoint. It sets pkt.NetworkHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() + + // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for + // this endpoint. + NetworkProtocolNumber() tcpip.NetworkProtocolNumber } // NetworkProtocol is the interface that needs to be implemented by network @@ -240,12 +293,12 @@ type NetworkProtocol interface { // DefaultPrefixLen returns the protocol's default prefix length. DefaultPrefixLen() int - // ParsePorts returns the source and destination addresses stored in a + // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -256,16 +309,45 @@ type NetworkProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() + + // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It + // returns: + // - The encapsulated protocol, if present. + // - Whether there is an encapsulated transport protocol payload (e.g. ARP + // does not encapsulate anything). + // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. + Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint - // and hands the packet over for further processing. linkHeader may have - // length 0 when the caller does not have ethernet data. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) + // and hands the packet over for further processing. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverNetworkPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverNetworkPacket takes ownership of pkt. + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -296,7 +378,9 @@ const ( // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. +// out through the implementer's data link endpoint. When a link header exists, +// it sets each PacketBuffer's LinkHeader field before passing it up the +// stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a @@ -318,28 +402,33 @@ type LinkEndpoint interface { // link endpoint. LinkAddress() tcpip.LinkAddress - // WritePacket writes a packet with the given protocol through the given - // route. + // WritePacket writes a packet with the given protocol through the + // given route. It takes ownership of pkt. pkt.NetworkHeader and + // pkt.TransportHeader must have already been set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the - // given route. + // given route. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. - WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error + // should already have an ethernet header. It takes ownership of vv. + WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. + // + // Attach will be called with a nil dispatcher if the receiver's associated + // NIC is being removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the @@ -354,6 +443,15 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -362,7 +460,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. @@ -374,12 +472,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. - // The request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts + // the request on the local network if remoteLinkAddr is the zero value. The + // request is sent on linkEP with localAddr as the source. // // A valid response will cause the discovery protocol's network // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -397,10 +496,10 @@ type LinkAddressResolver interface { type LinkAddressCache interface { // CheckLocalAddress determines if the given local address exists, and if it // does not exist. - CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver @@ -411,10 +510,10 @@ type LinkAddressCache interface { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1a0a51b57..e267bebb0 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -49,6 +48,10 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping + + // directedBroadcast indicates whether this route is sending a directed + // broadcast packet. + directedBroadcast bool } // makeRoute initializes a new route. It takes ownership of the provided @@ -107,6 +110,12 @@ func (r *Route) GSOMaxSize() uint32 { return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -114,6 +123,8 @@ func (r *Route) GSOMaxSize() uint32 { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). +// +// The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or @@ -149,77 +160,72 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. +// +// The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop) + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Size() + + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } return err } -// PacketDescriptor is a packet descriptor which contains a packet header and -// offset and size of packet data in a payload view. -type PacketDescriptor struct { - Hdr buffer.Prependable - Off int - Size int -} - -// NewPacketDescriptors allocates a set of packet descriptors. -func NewPacketDescriptors(n int, hdrSize int) []PacketDescriptor { - buf := make([]byte, n*hdrSize) - hdrs := make([]PacketDescriptor, n) - for i := range hdrs { - hdrs[i].Hdr = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) - } - return hdrs -} - -// WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams) (int, *tcpip.Error) { +// WritePackets writes a list of n packets through the given route and returns +// the number of packets written. +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - n, err := r.ref.ep.WritePackets(r, gso, hdrs, payload, params, r.Loop) + // WritePackets takes ownership of pkt, calculate length first. + numPkts := pkts.Len() + + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(hdrs) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - payloadSize := 0 - for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdrs[i].Hdr.UsedLength())) - payloadSize += hdrs[i].Size + + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil { + // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Data.Size() + + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } @@ -233,6 +239,12 @@ func (r *Route) MTU() uint32 { return r.ref.ep.MTU() } +// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying +// network endpoint. +func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return r.ref.ep.NetworkProtocolNumber() +} + // Release frees all resources associated with the route. func (r *Route) Release() { if r.ref != nil { @@ -244,7 +256,9 @@ func (r *Route) Release() { // Clone Clone a route such that the original one can be released and the new // one will remain valid. func (r *Route) Clone() Route { - r.ref.incRef() + if r.ref != nil { + r.ref.incRef() + } return *r } @@ -269,3 +283,36 @@ func (r *Route) MakeLoopedRoute() Route { func (r *Route) Stack() *Stack { return r.ref.stack() } + +// IsOutboundBroadcast returns true if the route is for an outbound broadcast +// packet. +func (r *Route) IsOutboundBroadcast() bool { + // Only IPv4 has a notion of broadcast. + return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast +} + +// IsInboundBroadcast returns true if the route is for an inbound broadcast +// packet. +func (r *Route) IsInboundBroadcast() bool { + // Only IPv4 has a notion of broadcast. + if r.LocalAddress == header.IPv4Broadcast { + return true + } + + addr := r.ref.addrWithPrefix() + subnet := addr.Subnet() + return subnet.IsBroadcast(r.LocalAddress) +} + +// ReverseRoute returns new route with given source and destination address. +func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { + return Route{ + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + ref: r.ref, + Loop: r.Loop, + } +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 71e0618f4..814b3e94a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,19 +20,20 @@ package stack import ( + "bytes" "encoding/binary" "math" - "sync" + mathrand "math/rand" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" @@ -54,7 +55,7 @@ const ( // fakeNetNumber is used as a protocol number in tests. // // This constant should match fakeNetNumber in stack_test.go. - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 ) type forwardingFlag uint32 @@ -77,8 +78,7 @@ func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag { case header.IPv6ProtocolNumber: flag = forwardingIPv6 case fakeNetNumber: - // This network protocol number is used in stack_test to test - // packet forwarding. + // This network protocol number is used to test packet forwarding. flag = forwardingFake default: // We only support forwarding for IPv4 and IPv6. @@ -88,7 +88,7 @@ func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag { type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -109,6 +109,16 @@ type TCPCubicState struct { WEst float64 } +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +type TCPRACKState struct { + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool +} + // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. type TCPEndpointID struct { // LocalPort is the local port associated with the endpoint. @@ -248,6 +258,9 @@ type TCPSenderState struct { // Cubic holds the state related to CUBIC congestion control. Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. @@ -271,11 +284,11 @@ type RcvBufAutoTuneParams struct { // was started. MeasureTime time.Time - // CopiedBytes is the number of bytes copied to user space since + // CopiedBytes is the number of bytes copied to userspace since // this measure began. CopiedBytes int - // PrevCopiedBytes is the number of bytes copied to user space in + // PrevCopiedBytes is the number of bytes copied to userspace in // the previous RTT period. PrevCopiedBytes int @@ -382,6 +395,45 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 + +func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it will be passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -399,12 +451,17 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC + mu sync.RWMutex + nics map[tcpip.NICID]*NIC // forwarding contains the enable bits for packet forwarding for different // network protocols. - forwarding uint32 + forwarding struct { + sync.RWMutex + flag forwardingFlag + } + + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -424,7 +481,8 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. - tables iptables.IPTables + // TODO(gvisor.dev/issue/170): S/R this field. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -434,23 +492,62 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // portSeed is a one-time random value initialized at stack startup + // seed is a one-time random value initialized at stack startup // and is used to seed the TCP port picking on active connections // // TODO(gvisor.dev/issue/940): S/R this field. - portSeed uint32 + seed uint32 // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations + // nudConfigs is the default NUD configurations used by interfaces. + nudConfigs NUDConfigurations + // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. - // See the AutoGenIPv6LinkLocal field of Options for more details. + // to auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. autoGenIPv6LinkLocal bool // ndpDisp is the NDP event dispatcher that is used to send the netstack // integrator NDP related events. ndpDisp NDPDispatcher + + // nudDisp is the NUD event dispatcher that is used to send the netstack + // integrator NUD related events. + nudDisp NUDDispatcher + + // uniqueIDGenerator is a generator of unique identifiers. + uniqueIDGenerator UniqueID + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + + // forwarder holds the packets that wait for their link-address resolutions + // to complete, and forwards them when each resolution is done. + forwarder *forwardQueue + + // randomGenerator is an injectable pseudo random generator that can be + // used when a random number is required. + randomGenerator *mathrand.Rand + + // sendBufferSize holds the min/default/max send buffer sizes for + // endpoints other than TCP. + sendBufferSize SendBufferSizeOption + + // receiveBufferSize holds the min/default/max receive buffer sizes for + // endpoints other than TCP. + receiveBufferSize ReceiveBufferSizeOption +} + +// UniqueID is an abstract generator of unique identifiers. +type UniqueID interface { + UniqueID() uint64 } // Options contains optional Stack configuration. @@ -474,6 +571,9 @@ type Options struct { // stack (false). HandleLocal bool + // UniqueID is an optional generator of unique identifiers. + UniqueID UniqueID + // NDPConfigs is the default NDP configurations used by interfaces. // // By default, NDPConfigs will have a zero value for its @@ -481,13 +581,18 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determins whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. + // NUDConfigs is the default NUD configurations used by interfaces. + NUDConfigs NUDConfigurations + + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address - // Detection is enabled, an address will only be assigned if it - // successfully resolves. If it fails, no further attempt will be made - // to auto-generate an IPv6 link-local address. + // will be assigned right away, or at all. If Duplicate Address Detection + // is enabled, an address will only be assigned if it successfully resolves. + // If it fails, no further attempt will be made to auto-generate an IPv6 + // link-local address. // // The generated link-local address will follow RFC 4291 Appendix A // guidelines. @@ -497,9 +602,39 @@ type Options struct { // receive NDP related events. NDPDisp NDPDispatcher + // NUDDisp is the NUD event dispatcher that an integrator can provide to + // receive NUD related events. + NUDDisp NUDDispatcher + // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory + + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // RandSource is an optional source to use to generate random + // numbers. If omitted it defaults to a Source seeded by the data + // returned by rand.Read(). + // + // RandSource must be thread-safe. + RandSource mathrand.Source + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -526,6 +661,51 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } +// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6 +// address and returns the network protocol number to be used to communicate +// with the specified address. It returns an error if the passed address is +// incompatible with the receiver. +// +// Preconditon: the parent endpoint mu must be held while calling this method. +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.NetProto + switch len(addr.Addr) { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + if header.IsV4MappedAddress(addr.Addr) { + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == header.IPv4Any { + addr.Addr = "" + } + } + } + + switch len(e.ID.LocalAddress) { + case header.IPv4AddressSize: + if len(addr.Addr) == header.IPv6AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + case header.IPv6AddressSize: + if len(addr.Addr) == header.IPv4AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + } + } + + switch { + case netProto == e.NetProto: + case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + if v6only { + return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + } + default: + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + + return addr, netProto, nil +} + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo // marker interface. func (*TransportEndpointInfo) IsEndpointInfo() {} @@ -546,24 +726,56 @@ func New(opts Options) *Stack { clock = &tcpip.StdClock{} } + if opts.UniqueID == nil { + opts.UniqueID = new(uniqueIDGenerator) + } + + randSrc := opts.RandSource + if randSrc == nil { + // Source provided by mathrand.NewSource is not thread-safe so + // we wrap it in a simple thread-safe version. + randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + } + // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() + opts.NUDConfigs.resetInvalidFields() + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), - portSeed: generateRandUint32(), + seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, + nudConfigs: opts.NUDConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + nudDisp: opts.NUDDisp, + opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, + forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, + receiveBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, } // Add specified network protocols. @@ -590,6 +802,16 @@ func New(opts Options) *Stack { return s } +// newJob returns a tcpip.Job using the Stack clock. +func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} + +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value @@ -651,16 +873,17 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h } } -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (s *Stack) NowNanoseconds() int64 { - return s.clock.NowNanoseconds() +// Clock returns the Stack's clock for retrieving the current time and +// scheduling work. +func (s *Stack) Clock() tcpip.Clock { + return s.clock } // Stats returns a mutable copy of the current stats. @@ -673,30 +896,55 @@ func (s *Stack) Stats() tcpip.Stats { // SetForwarding enables or disables packet forwarding between NICs. func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { + s.forwarding.Lock() + defer s.forwarding.Unlock() + + // If this stack does not support the protocol, do nothing. + if _, ok := s.networkProtocols[protocol]; !ok { + return + } + flag := getForwardingFlag(protocol) - for { - forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding)) - var newValue forwardingFlag + + // If the forwarding value for this protocol hasn't changed then do + // nothing. + if s.forwarding.flag&getForwardingFlag(protocol) != 0 == enable { + return + } + + var newValue forwardingFlag + if enable { + newValue = s.forwarding.flag | flag + } else { + newValue = s.forwarding.flag & ^flag + } + s.forwarding.flag = newValue + + // Enable or disable NDP for IPv6. + if protocol == header.IPv6ProtocolNumber { if enable { - newValue = forwarding | flag + for _, nic := range s.nics { + nic.becomeIPv6Router() + } } else { - newValue = forwarding & ^flag - } - if atomic.CompareAndSwapUint32(&s.forwarding, uint32(forwarding), uint32(newValue)) { - break + for _, nic := range s.nics { + nic.becomeIPv6Host() + } } } } // Forwarding returns if packet forwarding between NICs is enabled. func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - flag := getForwardingFlag(protocol) - forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding)) - return forwarding & flag != 0 + s.forwarding.RLock() + defer s.forwarding.RUnlock() + return s.forwarding.flag&getForwardingFlag(protocol) != 0 } // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. +// +// This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { s.mu.Lock() defer s.mu.Unlock() @@ -711,6 +959,13 @@ func (s *Stack) GetRouteTable() []tcpip.Route { return append([]tcpip.Route(nil), s.routeTable...) } +// AddRoute appends a route to the route table. +func (s *Stack) AddRoute(route tcpip.Route) { + s.mu.Lock() + defer s.mu.Unlock() + s.routeTable = append(s.routeTable, route) +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] @@ -751,9 +1006,32 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) } -// createNIC creates a NIC with the provided id and link-layer endpoint, and -// optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { +// NICContext is an opaque pointer used to store client-supplied NIC metadata. +type NICContext interface{} + +// NICOptions specifies the configuration of a NIC as it is being created. +// The zero value creates an enabled, unnamed NIC. +type NICOptions struct { + // Name specifies the name of the NIC. + Name string + + // Disabled specifies whether to avoid calling Attach on the passed + // LinkEndpoint. + Disabled bool + + // Context specifies user-defined data that will be returned in stack.NICInfo + // for the NIC. Clients of this library can use it to add metadata that + // should be tracked alongside a NIC, to avoid having to keep a + // map[tcpip.NICID]metadata mirroring stack.Stack's nic map. + Context NICContext +} + +// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and +// NICOptions. See the documentation on type NICOptions for details on how +// NICs can be configured. +// +// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -762,44 +1040,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, name, ep, loopback) + // Make sure name is unique, unless unnamed. + if opts.Name != "" { + for _, n := range s.nics { + if n.Name() == opts.Name { + return tcpip.ErrDuplicateNICID + } + } + } + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n - if enabled { + if !opts.Disabled { return n.enable() } return nil } -// CreateNIC creates a NIC with the provided id and link-layer endpoint. +// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls +// LinkEndpoint.Attach to bind ep with a NetworkDispatcher. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, true, false) + return s.CreateNICWithOptions(id, ep, NICOptions{}) } -// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, -// and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, false) -} - -// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer -// endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, true) -} - -// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, -// but leave it disable. Stack.EnableNIC must be called before the link-layer -// endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, false, false) -} - -// CreateDisabledNamedNIC is a combination of CreateNamedNIC and -// CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, false, false) +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, nic := range s.nics { + if nic.Name() == name { + return nic, true + } + } + return nil, false } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -808,36 +1082,72 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } return nic.enable() } -// CheckNIC checks if a NIC is usable. -func (s *Stack) CheckNIC(id tcpip.NICID) bool { +// DisableNIC disables the given NIC. +func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { s.mu.RLock() + defer s.mu.RUnlock() + nic, ok := s.nics[id] - s.mu.RUnlock() - if ok { - return nic.linkEP.IsAttached() + if !ok { + return tcpip.ErrUnknownNICID } - return false + + return nic.disable() } -// NICSubnets returns a map of NICIDs to their associated subnets. -func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { +// CheckNIC checks if a NIC is usable. +func (s *Stack) CheckNIC(id tcpip.NICID) bool { s.mu.RLock() defer s.mu.RUnlock() - nics := map[tcpip.NICID][]tcpip.Subnet{} + nic, ok := s.nics[id] + if !ok { + return false + } - for id, nic := range s.nics { - nics[id] = append(nics[id], nic.AddressRanges()...) + return nic.enabled() +} + +// RemoveNIC removes NIC and all related routes from the network stack. +func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.removeNICLocked(id) +} + +// removeNICLocked removes NIC and all related routes from the network stack. +// +// s.mu must be locked. +func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID } - return nics + delete(s.nics, id) + + // Remove routes in-place. n tracks the number of routes written. + n := 0 + for i, r := range s.routeTable { + s.routeTable[i] = tcpip.Route{} + if r.NIC != id { + // Keep this route. + s.routeTable[n] = r + n++ + } + } + + s.routeTable = s.routeTable[:n] + + return nic.remove() } // NICInfo captures the name and addresses assigned to a NIC. @@ -853,6 +1163,23 @@ type NICInfo struct { MTU uint32 Stats NICStats + + // Context is user-supplied data optionally supplied in CreateNICWithOptions. + // See type NICOptions for more details. + Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType +} + +// HasNIC returns true if the NICID is defined in the stack. +func (s *Stack) HasNIC(id tcpip.NICID) bool { + s.mu.RLock() + _, ok := s.nics[id] + s.mu.RUnlock() + return ok } // NICInfo returns a map of NICIDs to their associated information. @@ -864,9 +1191,9 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.linkEP.IsAttached(), + Running: nic.enabled(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0, + Loopback: nic.isLoopback(), } nics[id] = NICInfo{ Name: nic.name, @@ -875,6 +1202,8 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, + Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics @@ -936,35 +1265,6 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return nic.AddAddress(protocolAddress, peb) } -// AddAddressRange adds a range of addresses to the specified NIC. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.AddAddressRange(protocol, subnet) - return nil - } - - return tcpip.ErrUnknownNICID -} - -// RemoveAddressRange removes the range of addresses from the specified NIC. -func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.RemoveAddressRange(subnet) - return nil - } - - return tcpip.ErrUnknownNICID -} - // RemoveAddress removes an existing network-layer address from the specified // NIC. func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { @@ -991,9 +1291,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { return nics } -// GetMainNICAddress returns the first primary address and prefix for the given -// NIC and protocol. Returns an error if the NIC doesn't exist and an empty -// value if the NIC doesn't have a primary address for the given protocol. +// GetMainNICAddress returns the first non-deprecated primary address and prefix +// for the given NIC and protocol. If no non-deprecated primary address exists, +// a deprecated primary address and prefix will be returned. Returns an error if +// the NIC doesn't exist and an empty value if the NIC doesn't have a primary +// address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1003,17 +1305,12 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - for _, a := range nic.PrimaryAddresses() { - if a.Protocol == protocol { - return a.AddressWithPrefix, nil - } - } - return tcpip.AddressWithPrefix{}, nil + return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { if len(localAddr) == 0 { - return nic.primaryEndpoint(netProto) + return nic.primaryEndpoint(netProto, remoteAddr) } return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } @@ -1024,13 +1321,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isBroadcast := remoteAddr == header.IPv4Broadcast + isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil + if nic, ok := s.nics[id]; ok && nic.enabled() { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { + return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } } else { @@ -1038,18 +1335,25 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { + if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. - remoteAddr = ref.ep.ID().LocalAddress + remoteAddr = ref.address() } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback) - if needRoute { - r.NextHop = route.Gateway + r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) + r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) + + if len(route.Gateway) > 0 { + if needRoute { + r.NextHop = route.Gateway + } + } else if r.directedBroadcast { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } + return r, nil } } @@ -1073,13 +1377,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. -func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { +func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { s.mu.RLock() defer s.mu.RUnlock() // If a NIC is specified, we try to find the address there only. - if nicid != 0 { - nic := s.nics[nicid] + if nicID != 0 { + nic := s.nics[nicID] if nic == nil { return 0 } @@ -1138,35 +1442,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { } // AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.add(fullAddr, linkAddr) // TODO: provide a way for a transport endpoint to receive a signal // that AddLinkAddress for a particular address has been called. } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() - nic := s.nics[nicid] + nic := s.nics[nicID] if nic == nil { s.mu.RUnlock() return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // RemoveWaker implements LinkAddressCache.RemoveWaker. -func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { +func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { s.mu.RLock() defer s.mu.RUnlock() - if nic := s.nics[nicid]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + if nic := s.nics[nicID]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.removeWaker(fullAddr, waker) } } @@ -1175,14 +1479,45 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) +} + +// CheckRegisterTransportEndpoint checks if an endpoint can be registered with +// the stack transport dispatcher. +func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) +} + +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() + + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1206,6 +1541,81 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + +// Close closes all currently registered transport endpoints. +// +// Endpoints created or modified during this call may not get closed. +func (s *Stack) Close() { + for _, e := range s.RegisteredEndpoints() { + e.Abort() + } + for _, p := range s.transportProtocols { + p.proto.Close() + } + for _, p := range s.networkProtocols { + p.Close() + } +} + +// Wait waits for all transport and link endpoints to halt their worker +// goroutines. +// +// Endpoints created or modified during this call may not get waited on. +// +// Note that link endpoints must be stopped via an implementation specific +// mechanism. +func (s *Stack) Wait() { + for _, e := range s.RegisteredEndpoints() { + e.Wait() + } + for _, e := range s.CleanupEndpoints() { + e.Wait() + } + for _, p := range s.transportProtocols { + p.proto.Wait() + } + for _, p := range s.networkProtocols { + p.Wait() + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nics { + n.linkEP.Wait() + } +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { @@ -1280,9 +1690,9 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip // WritePacket writes data directly to the specified NIC. It adds an ethernet // header based on the arguments. -func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() - nic, ok := s.nics[nicid] + nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice @@ -1296,10 +1706,10 @@ func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto t } fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) fakeHeader.Encode(ðFields) - ethHeader := buffer.View(fakeHeader).ToVectorisedView() - ethHeader.Append(payload) + vv := buffer.View(fakeHeader).ToVectorisedView() + vv.Append(payload) - if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil { + if err := nic.linkEP.WriteRawPacket(vv); err != nil { return err } @@ -1308,9 +1718,9 @@ func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto t // WriteRawPacket writes data directly to the specified NIC without adding any // headers. -func (s *Stack) WriteRawPacket(nicid tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() - nic, ok := s.nics[nicid] + nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice @@ -1403,14 +1813,21 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC return tcpip.ErrUnknownNICID } -// IPTables returns the stack's iptables. -func (s *Stack) IPTables() iptables.IPTables { - return s.tables +// IsInGroup returns true if the NIC with ID nicID has joined the multicast +// group multicastAddr. +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.isInGroup(multicastAddr), nil + } + return false, tcpip.ErrUnknownNICID } -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt iptables.IPTables) { - s.tables = ipt +// IPTables returns the stack's iptables. +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent @@ -1489,16 +1906,66 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip } nic.setNDPConfigs(c) + return nil +} + +// NUDConfigurations gets the per-interface NUD configurations. +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return NUDConfigurations{}, tcpip.ErrUnknownNICID + } + + return nic.NUDConfigs() +} + +// SetNUDConfigurations sets the per-interface NUD configurations. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.setNUDConfigs(c) +} + +// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement +// message that it needs to handle. +func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.handleNDPRA(ip, ra) return nil } -// PortSeed returns a 32 bit value that can be used as a seed value for port -// picking. +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. // // NOTE: The seed is generated once during stack initialization only. -func (s *Stack) PortSeed() uint32 { - return s.portSeed +func (s *Stack) Seed() uint32 { + return s.seed +} + +// Rand returns a reference to a pseudo random generator that can be used +// to generate random numbers as required. +func (s *Stack) Rand() *mathrand.Rand { + return s.randomGenerator } func generateRandUint32() uint32 { @@ -1508,3 +1975,49 @@ func generateRandUint32() uint32 { } return binary.LittleEndian.Uint32(b) } + +func generateRandInt64() int64 { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + panic(err) + } + buf := bytes.NewReader(b) + var v int64 + if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { + panic(err) + } + return v +} + +// FindNetworkEndpoint returns the network endpoint for the given address. +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + for _, nic := range s.nics { + id := NetworkEndpointID{address} + + if ref, ok := nic.mu.endpoints[id]; ok { + nic.mu.RLock() + defer nic.mu.RUnlock() + + // An endpoint with this id exists, check if it can be + // used and return it. + return ref.ep, nil + } + } + return nil, tcpip.ErrBadAddress +} + +// FindNICNameFromID returns the name of the nic for the given NICID. +func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] + if !ok { + return "" + } + + return nic.Name() +} diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go new file mode 100644 index 000000000..0b093e6c5 --- /dev/null +++ b/pkg/tcpip/stack/stack_options.go @@ -0,0 +1,106 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "gvisor.dev/gvisor/pkg/tcpip" + +const ( + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4 KiB + + // DefaultBufferSize is the default size of the send/recv buffer for a + // transport endpoint. + DefaultBufferSize = 212 << 10 // 212 KiB + + // DefaultMaxBufferSize is the default maximum permitted size of a + // send/receive buffer. + DefaultMaxBufferSize = 4 << 20 // 4 MiB +) + +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +// SetOption allows setting stack wide options. +func (s *Stack) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SendBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.sendBufferSize = v + s.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.receiveBufferSize = v + s.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option allows retrieving stack wide options. +func (s *Stack) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SendBufferSizeOption: + s.mu.RLock() + *v = s.sendBufferSize + s.mu.RUnlock() + return nil + + case *ReceiveBufferSizeOption: + s.mu.RLock() + *v = s.receiveBufferSize + s.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index ef3d1beb0..f168be402 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,18 +21,24 @@ import ( "bytes" "fmt" "math" + "net" "sort" - "strings" "testing" "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( @@ -51,6 +57,10 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -61,9 +71,7 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { - nicid tcpip.NICID - id stack.NetworkEndpointID - prefixLen int + nicID tcpip.NICID proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint @@ -74,43 +82,35 @@ func (f *fakeNetworkEndpoint) MTU() uint32 { } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid -} - -func (f *fakeNetworkEndpoint) PrefixLen() int { - return f.prefixLen + return f.nicID } func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { - return &f.id -} - -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - - // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() - if len(nb) < fakeNetHeaderLen { + if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket( + tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + fakeNetNumber, + tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -125,37 +125,37 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return f.proto.Number() +} + +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := hdr.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) - - if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) - } - if loop&stack.PacketOut == 0 { + hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + hdr[dstAddrOffset] = r.RemoteAddress[0] + hdr[srcAddrOffset] = r.LocalAddress[0] + hdr[protocolNumberOffset] = byte(params.Protocol) + + if r.Loop&stack.PacketLoop != 0 { + f.HandlePacket(r, pkt) + } + if r.Loop&stack.PacketOut == 0 { return nil } - return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -197,18 +197,16 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { } func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { return &fakeNetworkEndpoint{ - nicid: nicid, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, + nicID: nicID, proto: f, dispatcher: dispatcher, ep: ep, - }, nil + } } func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { @@ -233,10 +231,53 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + +// Parse implements TransportProtocol.Parse. +func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) + if !ok { + return 0, false, false + } + return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true +} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } +// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify +// that LinkEndpoint.Attach was called. +type linkEPWithMockedAttach struct { + stack.LinkEndpoint + attached bool +} + +// Attach implements stack.LinkEndpoint.Attach. +func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { + l.LinkEndpoint.Attach(d) + l.attached = d != nil +} + +func (l *linkEPWithMockedAttach) isAttached() bool { + return l.attached +} + +// Checks to see if list contains an address. +func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { + for _, i := range list { + if i == item { + return true + } + } + + return false +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -261,8 +302,10 @@ func TestNetworkReceive(t *testing.T) { buf := buffer.NewView(30) // Make sure packet with wrong address is not delivered. - buf[0] = 3 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + buf[dstAddrOffset] = 3 + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -271,8 +314,10 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to first endpoint. - buf[0] = 1 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + buf[dstAddrOffset] = 1 + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -281,8 +326,10 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to second endpoint. - buf[0] = 2 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + buf[dstAddrOffset] = 2 + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -291,7 +338,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -301,7 +350,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -320,8 +371,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro } func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + })) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -376,7 +429,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -493,6 +548,340 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to +// a NetworkDispatcher when the NIC is created. +func TestAttachToLinkEndpointImmediately(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + nicOpts stack.NICOptions + }{ + { + name: "Create enabled NIC", + nicOpts: stack.NICOptions{Disabled: false}, + }, + { + name: "Create disabled NIC", + nicOpts: stack.NICOptions{Disabled: true}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + + if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } + }) + } +} + +func TestDisableUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := loopback.New() + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + checkNIC := func(enabled bool) { + t.Helper() + + allNICInfo := s.NICInfo() + nicInfo, ok := allNICInfo[nicID] + if !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } else if nicInfo.Flags.Running != enabled { + t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) + } + + if got := s.CheckNIC(nicID); got != enabled { + t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) + } + } + + // NIC should initially report itself as disabled. + checkNIC(false) + + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + checkNIC(true) + + // If the NIC is not reporting a correct enabled status, we cannot trust the + // next check so end the test here. + if t.Failed() { + t.FailNow() + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + checkNIC(false) +} + +func TestRemoveUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestRemoveNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } + + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) + } + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } +} + +func TestRouteWithDownNIC(t *testing.T) { + tests := []struct { + name string + downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + }{ + { + name: "Disabled NIC", + downFn: (*stack.Stack).DisableNIC, + upFn: (*stack.Stack).EnableNIC, + }, + + // Once a NIC is removed, it cannot be brought up. + { + name: "Removed NIC", + downFn: (*stack.Stack).RemoveNIC, + }, + } + + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + const addr1 = tcpip.Address("\x01") + const addr2 = tcpip.Address("\x02") + const nic1Dst = tcpip.Address("\x05") + const nic2Dst = tcpip.Address("\x06") + + setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + return s, ep1, ep2 + } + + // Tests that routes through a down NIC are not used when looking up a route + // for a destination. + t.Run("Find", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, _, _ := setup(t) + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Bringing NIC1 down should result in no routes to odd addresses. Routes to + // even addresses should continue to be available as NIC2 is still up. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Bringing NIC2 down should result in no routes to even addresses. No + // route should be available to any address as routes to odd addresses + // were made unavailable by bringing NIC1 down above. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + if upFn := test.upFn; upFn != nil { + // Bringing NIC1 up should make routes to odd addresses available + // again. Routes to even addresses should continue to be unavailable + // as NIC2 is still down. + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + } + }) + } + }) + + // Tests that writing a packet using a Route through a down NIC fails. + t.Run("WritePacket", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, ep1, ep2 := setup(t) + + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() + + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() + + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } + + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use NIC1 after being brought down should fail. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use NIC2 after being brought down should fail. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + if upFn := test.upFn; upFn != nil { + // Writes with Routes that use NIC1 after being brought up should + // succeed. + // + // TODO(gvisor.dev/issue/1491): Should we instead completely + // invalidate all Routes that were bound to a NIC that was brought + // down at some point? + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + } + }) + } + }) +} + func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has @@ -602,7 +991,7 @@ func TestAddressRemoval(t *testing.T) { buf := buffer.NewView(30) // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -652,7 +1041,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) @@ -671,11 +1060,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } } -func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { +func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { t.Helper() - info, ok := s.NICInfo()[nicid] + info, ok := s.NICInfo()[nicID] if !ok { - t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + t.Fatalf("NICInfo() failed to find nicID=%d", nicID) } if len(addr) == 0 { // No address given, verify that there is no address assigned to the NIC. @@ -708,7 +1097,7 @@ func TestEndpointExpiration(t *testing.T) { localAddrByte byte = 0x01 remoteAddr tcpip.Address = "\x03" noAddr tcpip.Address = "" - nicid tcpip.NICID = 1 + nicID tcpip.NICID = 1 ) localAddr := tcpip.Address([]byte{localAddrByte}) @@ -720,7 +1109,7 @@ func TestEndpointExpiration(t *testing.T) { }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -734,16 +1123,16 @@ func TestEndpointExpiration(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte if promiscuous { - if err := s.SetPromiscuousMode(nicid, true); err != nil { + if err := s.SetPromiscuousMode(nicID, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } } if spoofing { - if err := s.SetSpoofing(nicid, true); err != nil { + if err := s.SetSpoofing(nicID, true); err != nil { t.Fatal("SetSpoofing failed:", err) } } @@ -751,7 +1140,7 @@ func TestEndpointExpiration(t *testing.T) { // 1. No Address yet, send should only work for spoofing, receive for // promiscuous mode. //----------------------- - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -766,20 +1155,20 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -794,10 +1183,10 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -815,10 +1204,10 @@ func TestEndpointExpiration(t *testing.T) { // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -834,10 +1223,10 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) testSend(t, r, ep, nil) @@ -845,17 +1234,17 @@ func TestEndpointExpiration(t *testing.T) { // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -897,7 +1286,7 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. @@ -1071,19 +1460,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } // If the NIC doesn't exist, it won't work. if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } } @@ -1109,12 +1498,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { } nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } // Set the initial route table. @@ -1129,10 +1518,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // When an interface is given, the route for a broadcast goes through it. r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } // When an interface is not given, it consults the route table. @@ -1254,149 +1643,6 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Add a range of addresses, then check that a packet is delivered. -func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { - t.Helper() - - // Loop over all addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - - addrBytes := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - addr := tcpip.Address(addrBytes) - wantNicID := nicID - // The subnet and broadcast addresses are skipped. - if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { - wantNicID = 0 - } - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) - } - addrBytes[0]++ - } - - // Trying the next address should always fail since it is outside the range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) - } -} - -// Set a range of addresses, then remove it again, and check at each step that -// CheckLocalAddress returns the correct NIC for each address or zero if not -// existent. -func TestCheckLocalAddressForSubnet(t *testing.T) { - const nicID tcpip.NICID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) - - if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) - - if err := s.RemoveAddressRange(nicID, subnet); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) -} - -// Set a range of addresses, then send a packet to a destination outside the -// range and then check it doesn't get delivered. -func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[0] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - func TestNetworkOptions(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, @@ -1440,56 +1686,6 @@ func TestNetworkOptions(t *testing.T) { } } -func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { - ranges, ok := s.NICAddressRanges()[id] - if !ok { - return false - } - for _, r := range ranges { - if r == addrRange { - return true - } - } - return false -} - -func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addr := tcpip.Address("\x01\x01\x01\x01") - mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - addrRange, err := tcpip.NewSubnet(addr, mask) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.RemoveAddressRange(1, addrRange); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } -} - func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for _, addrLen := range []int{4, 16} { t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { @@ -1648,12 +1844,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } func TestAddAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1661,7 +1857,7 @@ func TestAddAddress(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) for _, addrLen := range []int{4, 16} { address := addrGen.next(addrLen) - if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { t.Fatalf("AddAddress(address=%s) failed: %s", address, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1670,17 +1866,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1697,24 +1893,24 @@ func TestAddProtocolAddress(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) } expectedAddresses = append(expectedAddresses, protocolAddress) } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1725,7 +1921,7 @@ func TestAddAddressWithOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1735,17 +1931,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1764,7 +1960,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) } expectedAddresses = append(expectedAddresses, protocolAddress) @@ -1772,10 +1968,95 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } +func TestCreateNICWithOptions(t *testing.T) { + type callArgsAndExpect struct { + nicID tcpip.NICID + opts stack.NICOptions + err *tcpip.Error + } + + tests := []struct { + desc string + calls []callArgsAndExpect + }{ + { + desc: "DuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth1"}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth2"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "DuplicateName", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "lo"}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{Name: "lo"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "Unnamed", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{}, + err: nil, + }, + }, + }, + { + desc: "UnnamedDuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + s := stack.New(stack.Options{}) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + for _, call := range test.calls { + if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { + t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) + } + } + }) + } +} + func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, @@ -1798,7 +2079,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1823,150 +2106,386 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(fakeNetNumber, true) + const nicID1 = 1 + const nicID2 = 2 + const dstAddr = tcpip.Address("\x03") - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + tests := []struct { + name string + headerLen uint16 + }{ + { + name: "Zero header length", + }, + { + name: "Non-zero header length", + headerLen: 16, + }, } - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(fakeNetNumber, true) + + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) + } + + ep2 := channelLinkWithHeaderLength{ + Endpoint: channel.New(10, defaultMTU, ""), + headerLength: test.headerLen, + } + if err := s.CreateNIC(nicID2, &ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) + } + + // Route all packets to dstAddr to NIC 2. + { + subnet, err := tcpip.NewSubnet(dstAddr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) + } + + // Send a packet to dstAddr. + buf := buffer.NewView(30) + buf[dstAddrOffset] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not forwarded") + } + + // Test that the link's MaxHeaderLength is honoured. + if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { + t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) + } + + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } + + if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + }) } +} - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) +// TestNICContextPreservation tests that you can read out via stack.NICInfo the +// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. +func TestNICContextPreservation(t *testing.T) { + var ctx *int + tests := []struct { + name string + opts stack.NICOptions + want stack.NICContext + }{ + { + "context_set", + stack.NICOptions{Context: ctx}, + ctx, + }, + { + "context_not_set", + stack.NICOptions{}, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + id := tcpip.NICID(1) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { + t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) + } + nicinfos := s.NICInfo() + nicinfo, ok := nicinfos[id] + if !ok { + t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) + } + if got, want := nicinfo.Context == test.want, true; got != want { + t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + } + }) } +} - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) +// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local +// addresses. +func TestNICAutoGenLinkLocalAddr(t *testing.T) { + const nicID = 1 - select { - case <-ep2.C: - default: - t.Fatal("Packet not forwarded") + var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte + n, err := rand.Read(secretKey[:]) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) } - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + nicNameFunc := func(_ tcpip.NICID, name string) string { + return name } -} -// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// (or lack there-of if disabled (default)). Note, DAD will be disabled in -// these tests. -func TestNICAutoGenAddr(t *testing.T) { tests := []struct { - name string - autoGen bool - linkAddr tcpip.LinkAddress - shouldGen bool + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions + shouldGen bool + expectedAddr tcpip.Address }{ { - "Disabled", - false, - linkAddr1, - false, + name: "Disabled", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + shouldGen: false, + }, + { + name: "Disabled without OIID options", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: false, + }, + + // Tests for EUI64 based addresses. + { + name: "EUI64 Enabled", + autoGen: true, + linkAddr: linkAddr1, + shouldGen: true, + expectedAddr: header.LinkLocalAddr(linkAddr1), + }, + { + name: "EUI64 Empty MAC", + autoGen: true, + shouldGen: false, }, { - "Enabled", - true, - linkAddr1, - true, + name: "EUI64 Invalid MAC", + autoGen: true, + linkAddr: "\x01\x02\x03", + shouldGen: false, }, { - "Nil MAC", - true, - tcpip.LinkAddress([]byte(nil)), - false, + name: "EUI64 Multicast MAC", + autoGen: true, + linkAddr: "\x01\x02\x03\x04\x05\x06", + shouldGen: false, }, { - "Empty MAC", - true, - tcpip.LinkAddress(""), - false, + name: "EUI64 Unspecified MAC", + autoGen: true, + linkAddr: "\x00\x00\x00\x00\x00\x00", + shouldGen: false, }, + + // Tests for Opaque IID based addresses. + { + name: "OIID Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), + }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. { - "Invalid MAC", - true, - tcpip.LinkAddress("\x01\x02\x03"), - false, + name: "OIID Empty MAC and empty nicName", + autoGen: true, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:1], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), }, { - "Multicast MAC", - true, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - false, + name: "OIID Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: "\x01\x02\x03", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:2], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), }, { - "Unspecified MAC", - true, - tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), - false, + name: "OIID Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: "\x01\x02\x03\x04\x05\x06", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:3], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), + }, + { + name: "OIID Unspecified MAC and nil SecretKey", + nicName: "test3", + autoGen: true, + linkAddr: "\x00\x00\x00\x00\x00\x00", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, } - if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. - opts.AutoGenIPv6LinkLocal = true + e := channel.New(0, 1280, test.linkAddr) + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - e := channel.New(10, 1280, test.linkAddr) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + // A new disabled NIC should not have any address, even if auto generation + // was enabled. + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) } - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + // Enabling the NIC should attempt auto-generation of a link-local + // address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } + var expectedMainAddr tcpip.AddressWithPrefix if test.shouldGen { - // Should have auto-generated an address and - // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + expectedMainAddr = tcpip.AddressWithPrefix{ + Address: test.expectedAddr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + } + + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } } else { // Should not have auto-generated an address. - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address") + default: } } + + gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if gotMainAddr != expectedMainAddr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + } + }) + } +} + +// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are +// not auto-generated for loopback NICs. +func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { + const nicID = 1 + const nicName = "nicName" + + tests := []struct { + name string + opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + }{ + { + name: "IID From MAC", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + }, + { + name: "Opaque IID", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + } + + e := loopback.New() + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + } }) } } @@ -1974,6 +2493,8 @@ func TestNICAutoGenAddr(t *testing.T) { // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -1985,20 +2506,20 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { NDPDisp: &ndpDisp, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } linkLocalAddr := header.LinkLocalAddr(linkAddr1) @@ -2012,25 +2533,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) - } - if e.addr != linkLocalAddr { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } } @@ -2078,7 +2590,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { - t.Fatalf("NewSubnet failed:", err) + t.Fatalf("NewSubnet failed: %v", err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } @@ -2092,11 +2604,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // permanentExpired kind. r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) if err != nil { - t.Fatal("FindRoute failed:", err) + t.Fatalf("FindRoute failed: %v", err) } defer r.Release() if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } // @@ -2108,7 +2620,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } @@ -2116,7 +2628,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // make sure the new peb was respected. // (The address should just be promoted now). if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[1].ProtocolAddresses { @@ -2149,11 +2661,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // GetMainNICAddress; else, our original address // should be returned. if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } addr, err = s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + t.Fatalf("s.GetMainNICAddress failed: %v", err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { @@ -2169,3 +2681,858 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { } } } + +func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { + const ( + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 + lifetimeSeconds = 9999 + ) + + prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) + + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address + tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address + + // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. + tests := []struct { + name string + slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix + nicAddrs []tcpip.Address + slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix + connectAddr tcpip.Address + expectedLocalAddr tcpip.Address + }{ + // Test Rule 1 of RFC 6724 section 5. + { + name: "Same Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Same Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test Rule 2 of RFC 6724 section 5. + { + name: "Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred for link local multicast (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred for link local multicast (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test Rule 7 of RFC 6724 section 5. + { + name: "Temp Global most preferred (last address)", + slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: tempGlobalAddr1, + }, + { + name: "Temp Global most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, + connectAddr: globalAddr2, + expectedLocalAddr: tempGlobalAddr1, + }, + + // Test returning the endpoint that is closest to the front when + // candidate addresses are "equal" from the perspective of RFC 6724 + // section 5. + { + name: "Unique Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Link Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local for Unique Local", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Temp Global for Global", + slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, + slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, + connectAddr: globalAddr1, + expectedLocalAddr: tempGlobalAddr2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDispatcher{}, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + + if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) + } + + for _, a := range test.nicAddrs { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { + t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + } + } + + if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) { + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) + } + + if t.Failed() { + t.FailNow() + } + + if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) + } + }) + } +} + +func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { + const nicID = 1 + broadcastAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } + } + + // Enabling the NIC should add the IPv4 broadcast address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if !containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) + } + } + + // Disabling the NIC should remove the IPv4 broadcast address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } + } +} + +// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 +// address after leaving its solicited node multicast address does not result in +// an error. +func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + } + + // The NIC should have joined addr1's solicited node multicast address. + snmc := header.SolicitedNodeAddr(addr1) + in, err := s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if !in { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) + } + + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) + } + in, err = s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if in { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) + } + + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + } +} + +func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + name: "IPv6 All-Nodes", + proto: header.IPv6ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + }, + { + name: "IPv4 All-Systems", + proto: header.IPv4ProtocolNumber, + addr: header.IPv4AllSystems, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + // Should not be in the multicast group yet because the NIC has not been + // enabled yet. + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // The multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // Leaving the group before disabling the NIC should not cause an error. + if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { + t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + }) + } +} + +// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC +// was disabled have DAD performed on them when the NIC is enabled. +func TestDoDADWhenNICEnabled(t *testing.T) { + const dadTransmits = 1 + const retransmitTimer = time.Second + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + } + + e := channel.New(dadTransmits, 1280, linkAddr1) + s := stack.New(opts) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: llAddr1, + PrefixLen: 128, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + // Address should be in the list of all addresses. + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should be tentative so it should not be a main address. + got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Enabling the NIC should start DAD for the address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } + + // Enabling the NIC again should be a no-op. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } +} + +func TestStackReceiveBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + rs stack.ReceiveBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.rs); err != tc.err { + t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) + } + var rs stack.ReceiveBufferSizeOption + if tc.err == nil { + if err := s.Option(&rs); err != nil { + t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) + } + if got, want := rs, tc.rs; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestStackSendBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + ss stack.SendBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.ss); err != tc.err { + t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + } + var ss stack.SendBufferSizeOption + if tc.err == nil { + if err := s.Option(&ss); err != nil { + t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) + } + if got, want := ss, tc.ss; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID1 = 1 + ) + + defaultAddr := tcpip.AddressWithPrefix{ + Address: header.IPv4Any, + PrefixLen: 0, + } + defaultSubnet := defaultAddr.Subnet() + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedRoute stack.Route + }{ + // Broadcast to a locally attached subnet populates the broadcast MAC. + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: ipv4SubnetBcast, + RemoteLinkAddress: header.EthernetBroadcastAddress, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /31 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix31.Address, + RemoteAddress: ipv4Subnet31Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /32 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix32.Address, + RemoteAddress: ipv4Subnet32Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv6Addr.Address, + RemoteAddress: ipv6SubnetBcast, + NetProto: header.IPv6ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a remote subnet in the route table is send to the next-hop + // gateway. + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to an unknown subnet follows the default route. Note that this + // is essentially just routing an unknown destination IP, because w/o any + // subnet prefix information a subnet broadcast address is just a normal IP. + { + name: "IPv4 Broadcast to unknown subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: defaultSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) + } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { + t.Errorf("route mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestResolveWith(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + }) + ep := channel.New(0, defaultMTU, "") + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) + + remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) + } + defer r.Release() + + // Should initially require resolution. + if !r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = false, want = true") + } + + // Manually resolving the route should no longer require resolution. + r.ResolveWith("\x01") + if r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = true, want = false") + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 97a1aec4b..b902c6ca9 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,12 +17,12 @@ package stack import ( "fmt" "math/rand" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" ) type protocolIDs struct { @@ -35,28 +35,130 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]*endpointsByNic + endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint } -type endpointsByNic struct { +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + eps.mu.Lock() + defer eps.mu.Unlock() + epsByNIC, ok := eps.endpoints[id] + if !ok { + return + } + if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { + // Try to find a match with the id as provided. + if ep, ok := eps.endpoints[id]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the local address. + nid := id + + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with only the local port. + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } +} + +// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in +// descending order of match quality. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { + var matchedEPs []*endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given id. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { + var matchedEP *endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEP = ep + return false + }) + return matchedEP +} + +type endpointsByNIC struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 } +func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNIC.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - epsByNic.mu.RLock() +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { + epsByNIC.mu.RLock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } } @@ -64,24 +166,30 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { - mpep.handlePacketAll(r, id, vv) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + mpep.handlePacketAll(r, id, pkt) + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } - // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + transEP := selectEndpoint(id, mpep, epsByNIC.seed) + if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { + queuedProtocol.QueuePacket(r, transEP, id, pkt) + epsByNIC.mu.RUnlock() + return + } + + transEP.HandlePacket(r, id, pkt) + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() - mpep, ok := epsByNic.endpoints[n.ID()] + mpep, ok := epsByNIC.endpoints[n.ID()] if !ok { - mpep, ok = epsByNic.endpoints[0] + mpep, ok = epsByNIC.endpoints[0] } if !ok { return @@ -91,55 +199,52 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() - if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { - // There was already a bind. - return multiPortEp.singleRegisterEndpoint(t, reusePort) + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] + if !ok { + multiPortEp = &multiPortEndpoint{ + demux: d, + netProto: netProto, + transProto: transProto, + } + epsByNIC.endpoints[bindToDevice] = multiPortEp } - // This is a new binding. - multiPortEp := &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.reuse = reusePort - epsByNic.endpoints[bindToDevice] = multiPortEp - return multiPortEp.singleRegisterEndpoint(t, reusePort) + return multiPortEp.singleRegisterEndpoint(t, flags) } -// unregisterEndpoint returns true if endpointsByNic has to be unregistered. -func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() + + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { - return false - } - if multiPortEp.unregisterEndpoint(t) { - delete(epsByNic.endpoints, bindToDevice) + return nil } - return len(epsByNic.endpoints) == 0 + + return multiPortEp.singleCheckEndpoint(flags) } -// unregisterEndpoint unregisters the endpoint with the given id such that it -// won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - eps.mu.Lock() - defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] +// unregisterEndpoint returns true if endpointsByNIC has to be unregistered. +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { - return + return false } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { - return + if multiPortEp.unregisterEndpoint(t, flags) { + delete(epsByNIC.endpoints, bindToDevice) } - delete(eps.endpoints, id) + return len(epsByNIC.endpoints) == 0 } // transportDemuxer demultiplexes packets targeted at a transport endpoint @@ -149,17 +254,33 @@ func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep Tra // newTransportDemuxer. type transportDemuxer struct { // protocol is immutable. - protocol map[protocolIDs]*transportEndpoints + protocol map[protocolIDs]*transportEndpoints + queuedProtocols map[protocolIDs]queuedTransportProtocol +} + +// queuedTransportProtocol if supported by a protocol implementation will cause +// the dispatcher to delivery packets to the QueuePacket method instead of +// calling HandlePacket directly on the endpoint. +type queuedTransportProtocol interface { + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { - d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + d := &transportDemuxer{ + protocol: make(map[protocolIDs]*transportEndpoints), + queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), + } // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]*endpointsByNic), + protoIDs := protocolIDs{netProto, proto} + d.protocol[protoIDs] = &transportEndpoints{ + endpoints: make(map[TransportEndpointID]*endpointsByNIC), + } + qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) + if isQueued { + d.queuedProtocols[protoIDs] = qTransProto } } } @@ -169,10 +290,21 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) + return err + } + } + + return nil +} + +// checkEndpoint checks if an endpoint can be registered with the dispatcher. +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + for _, n := range netProtos { + if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { return err } } @@ -183,12 +315,29 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but not not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex - endpointsArr []TransportEndpoint - endpointsMap map[TransportEndpoint]int - // reuse indicates if more than one endpoint is allowed. - reuse bool + mu sync.RWMutex `state:"nosave"` + demux *transportDemuxer + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + + // endpoints stores the transport endpoints in the order in which they + // were bound. This is required for UDP SO_REUSEADDR. + endpoints []TransportEndpoint + flags ports.FlagCounter +} + +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpoints...) + ep.mu.RUnlock() + return eps } // reciprocalScale scales a value into range [0, n). @@ -203,8 +352,12 @@ func reciprocalScale(val, n uint32) uint32 { // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpointsArr) == 1 { - return mpep.endpointsArr[0] + if len(mpep.endpoints) == 1 { + return mpep.endpoints[0] + } + + if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + return mpep.endpoints[len(mpep.endpoints)-1] } payload := []byte{ @@ -220,72 +373,89 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) - return mpep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) + return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy except for - // the final one. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break + queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. + for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + } else { + endpoint.HandlePacket(r, id, pkt.Clone()) } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + } + if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + } else { + endpoint.HandlePacket(r, id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if len(ep.endpointsArr) > 0 { + bits := flags.Bits() & ports.MultiBindFlagMask + + if len(ep.endpoints) != 0 { + // If it was previously bound, we need to check if we can bind again. + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + return tcpip.ErrPortInUse + } + } + + ep.endpoints = append(ep.endpoints, t) + ep.flags.AddRef(bits) + + return nil +} + +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { + ep.mu.RLock() + defer ep.mu.RUnlock() + + bits := flags.Bits() & ports.MultiBindFlagMask + + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if !ep.reuse || !reusePort { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { return tcpip.ErrPortInUse } } - // A new endpoint is added into endpointsArr and its index there is saved in - // endpointsMap. This will allow us to remove endpoint from the array fast. - ep.endpointsMap[t] = len(ep.endpointsArr) - ep.endpointsArr = append(ep.endpointsArr, t) return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. -func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { +func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() defer ep.mu.Unlock() - idx, ok := ep.endpointsMap[t] - if !ok { - return false - } - delete(ep.endpointsMap, t) - l := len(ep.endpointsArr) - if l > 1 { - // The last endpoint in endpointsArr is moved instead of the deleted one. - lastEp := ep.endpointsArr[l-1] - ep.endpointsArr[idx] = lastEp - ep.endpointsMap[lastEp] = idx - ep.endpointsArr = ep.endpointsArr[0 : l-1] - return false + for i, endpoint := range ep.endpoints { + if endpoint == t { + copy(ep.endpoints[i:], ep.endpoints[i+1:]) + ep.endpoints[len(ep.endpoints)-1] = nil + ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] + + ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask) + break + } } - return true + return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { - // TODO(eyalsoha): Why? - reusePort = false + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] @@ -296,82 +466,109 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - if epsByNic, ok := eps.endpoints[id]; ok { - // There was already a binding. - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + epsByNIC, ok := eps.endpoints[id] + if !ok { + epsByNIC = &endpointsByNIC{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), + } + eps.endpoints[id] = epsByNIC + } + + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) +} + +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return tcpip.ErrUnknownProtocol } - // This is a new binding. - epsByNic := &endpointsByNic{ - endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + eps.mu.RLock() + defer eps.mu.RUnlock() + + epsByNIC, ok := eps.endpoints[id] + if !ok { + return nil } - eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep, bindToDevice) + eps.unregisterEndpoint(id, ep, flags, bindToDevice) } } } -var loopbackSubnet = func() tcpip.Subnet { - sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") - if err != nil { - panic(err) - } - return sn -}() - // deliverPacket attempts to find one or more matching transport endpoints, and -// then, if matches are found, delivers the packet to them. Returns true if it -// found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { +// then, if matches are found, delivers the packet to them. Returns true if +// the packet no longer needs to be handled. +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false } - eps.mu.RLock() - - // Determine which transport endpoint or endpoints to deliver this packet to. - // If the packet is a broadcast or multicast, then find all matching + // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - var destEps []*endpointsByNic if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, vv, id) - } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { - destEps = append(destEps, ep) + eps.mu.RLock() + destEPs := eps.findAllEndpointsLocked(id) + eps.mu.RUnlock() + // Fail if we didn't find at least one matching transport endpoint. + if len(destEPs) == 0 { + r.Stats().UDP.UnknownPortErrors.Increment() + return false + } + // handlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEPs[:len(destEPs)-1] { + ep.handlePacket(r, id, pkt.Clone()) + } + destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + return true } - eps.mu.RUnlock() + // If the packet is a TCP packet with a non-unicast source or destination + // address, then do nothing further and instruct the caller to do the same. + if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single source and a + // single destination; the addresses must be unicast. + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true + } - // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { - // UDP packet could not be delivered to an unknown destination port. + eps.mu.RLock() + ep := eps.findEndpointLocked(id) + eps.mu.RUnlock() + if ep == nil { if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() } return false } - - // Deliver the packet. - for _, ep := range destEps { - ep.handlePacket(r, id, vv) - } - + ep.handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -385,7 +582,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + rawEP.HandlePacket(r, pkt) foundRaw = true } eps.mu.RUnlock() @@ -395,67 +592,51 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false } - // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() - - // Fail if we didn't find one. if ep == nil { return false } - // Deliver the packet. - ep.handleControlPacket(n, id, typ, extra, vv) - + ep.handleControlPacket(n, id, typ, extra, pkt) return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic - // Try to find a match with the id as provided. - if ep, ok := eps.endpoints[id]; ok { - matchedEPs = append(matchedEPs, ep) +// findTransportEndpoint find a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil } - // Try to find a match with the id minus the local address. - nid := id - - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + eps.mu.RLock() + epsByNIC := eps.findEndpointLocked(id) + if epsByNIC == nil { + eps.mu.RUnlock() + return nil } - // Try to find a match with the id minus the remote part. - nid.LocalAddress = id.LocalAddress - nid.RemoteAddress = "" - nid.RemotePort = 0 - if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) - } + epsByNIC.mu.RLock() + eps.mu.RUnlock() - // Try to find a match with only the local port. - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } } - return matchedEPs -} - -// findEndpointLocked returns the endpoint that most closely matches the given -// id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { - return matchedEPs[0] - } - return nil + ep := selectEndpoint(id, mpep, epsByNIC.seed) + epsByNIC.mu.RUnlock() + return ep } // registerRawEndpoint registers the given endpoint with the dispatcher such @@ -469,8 +650,8 @@ func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNum } eps.mu.Lock() - defer eps.mu.Unlock() eps.rawEndpoints = append(eps.rawEndpoints, ep) + eps.mu.Unlock() return nil } @@ -484,15 +665,22 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } eps.mu.Lock() - defer eps.mu.Unlock() for i, rawEP := range eps.rawEndpoints { if rawEP == ep { - eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) - return + lastIdx := len(eps.rawEndpoints) - 1 + eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx] + eps.rawEndpoints[lastIdx] = nil + eps.rawEndpoints = eps.rawEndpoints[:lastIdx] + break } } + eps.mu.Unlock() } func isMulticastOrBroadcast(addr tcpip.Address) bool { return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) } + +func isUnicast(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 210233dc0..1339edc2d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -25,96 +25,65 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testPort = 4096 + testSrcAddrV4 = "\x0a\x00\x00\x01" + testDstAddrV4 = "\x0a\x00\x00\x02" + + testDstPort = 1234 + testSrcPort = 4096 ) type testContext struct { - t *testing.T - linkEPs map[string]*channel.Endpoint + linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue + wq waiter.Queue } -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } -} - -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { +// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. +func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicid := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + linkEps := make(map[tcpip.NICID]*channel.Endpoint) + for _, linkEpID := range linkEpIDs { + channelEp := channel.New(256, mtu, "") + if err := s.CreateNIC(linkEpID, channelEp); err != nil { + t.Fatalf("CreateNIC failed: %s", err) } - linkEPs[linkEpName] = channelEP + linkEps[linkEpID] = channelEp - if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress IPv4 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { + t.Fatalf("AddAddress IPv4 failed: %s", err) } - if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress IPv6 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { + t.Fatalf("AddAddress IPv6 failed: %s", err) } } s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, + {Destination: header.IPv4EmptySubnet, NIC: 1}, + {Destination: header.IPv6EmptySubnet, NIC: 1}, }) return &testContext{ - t: t, s: s, - linkEPs: linkEPs, + linkEps: linkEps, } } type headers struct { - srcPort uint16 - dstPort uint16 + srcPort, dstPort uint16 } func newPayload() []byte { @@ -125,7 +94,47 @@ func newPayload() []byte { return b } -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { +func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: 0x80, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: testSrcAddrV4, + DstAddr: testDstAddrV4, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + }) + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -136,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -149,14 +158,17 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + }) + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) } func TestTransportDemuxerRegister(t *testing.T) { @@ -171,95 +183,105 @@ func TestTransportDemuxerRegister(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) } } -// TestReuseBindToDevice injects varied packets on input devices and checks that +// TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. -func TestDistribution(t *testing.T) { +func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { - reuse int - bindToDevice string + reuse bool + bindToDevice tcpip.NICID } for _, test := range []struct { name string // endpoints will received the inject packets. endpoints []endpointSockopts - // wantedDistribution is the wanted ratio of packets received on each + // wantDistributions is the want ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 + wantDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{0, "dev0"}, - endpointSockopts{0, "dev1"}, - endpointSockopts{0, "dev2"}, + {reuse: false, bindToDevice: 1}, + {reuse: false, bindToDevice: 2}, + {reuse: false, bindToDevice: 3}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": []float64{1, 0, 0}, + 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": []float64{0, 1, 0}, + 2: {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": []float64{0, 0, 1}, + 3: {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, ""}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + 1: {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": []float64{0, 0, 0, 0, 0, 1}, + 1000: {0, 0, 0, 0, 0, 1}, }, }, } { - t.Run(test.name, func(t *testing.T) { - for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string - for d := range test.wantedDistributions { + for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } { + for device, wantDistribution := range test.wantDistributions { + t.Run(test.name+protoName+string(device), func(t *testing.T) { + var devices []tcpip.NICID + for d := range test.wantDistributions { devices = append(devices, d) } - c := newDualTestContextMultiNic(t, defaultMTU, devices) - defer c.cleanup() - - c.createV6Endpoint(false) + c := newDualTestContextMultiNIC(t, defaultMTU, devices) eps := make(map[tcpip.Endpoint]int) @@ -273,9 +295,9 @@ func TestDistribution(t *testing.T) { defer close(ch) var err *tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } eps[ep] = i @@ -286,22 +308,31 @@ func TestDistribution(t *testing.T) { }(ep) defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { + t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) + } + + var dstAddr tcpip.Address + switch netProtoNum { + case ipv4.ProtocolNumber: + dstAddr = testDstAddrV4 + case ipv6.ProtocolNumber: + dstAddr = testDstAddrV6 + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) } - if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } npackets := 100000 nports := 10000 - if got, want := len(test.endpoints), len(wantedDistribution); got != want { + if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } ports := make(map[uint16]tcpip.Endpoint) @@ -310,17 +341,22 @@ func TestDistribution(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, - &headers{ - srcPort: testPort + port, - dstPort: stackPort}, - device) + hdrs := &headers{ + srcPort: testSrcPort + port, + dstPort: testDstPort, + } + switch netProtoNum { + case ipv4.ProtocolNumber: + c.sendV4Packet(payload, hdrs, device) + case ipv6.ProtocolNumber: + c.sendV6Packet(payload, hdrs, device) + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) + } - var addr tcpip.FullAddress ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + if _, _, err := ep.Read(nil); err != nil { + t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ if i < nports { @@ -336,17 +372,17 @@ func TestDistribution(t *testing.T) { // Check that a packet distribution is as expected. for ep, i := range eps { - wantedRatio := wantedDistribution[i] - wantedRecv := wantedRatio * float64(npackets) + wantRatio := wantDistribution[i] + wantRecv := wantRatio * float64(npackets) actualRecv := stats[ep] actualRatio := float64(stats[ep]) / float64(npackets) // The deviation is less than 10%. - if math.Abs(actualRatio-wantedRatio) > 0.05 { - t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + if math.Abs(actualRatio-wantRatio) > 0.05 { + t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) } } }) } - }) + } } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 6d3daed24..fa4b14ba6 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -19,9 +19,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -43,6 +43,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint @@ -56,8 +57,14 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} +func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} + +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} +} + +func (f *fakeTransportEndpoint) Abort() { + f.Close() } func (f *fakeTransportEndpoint) Close() { @@ -77,12 +84,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, + Data: buffer.View(v).ToVectorisedView(), + }) + _ = pkt.TransportHeader().Push(fakeTransHeaderLen) + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, nil, err } @@ -98,13 +109,23 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptBool sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } @@ -134,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -144,6 +165,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -175,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, /* reuse */ - 0, /* bindtoDevice */ + ports.Flags{}, + 0, /* bindtoDevice */ ); err != nil { return err } @@ -192,7 +217,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -209,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -218,15 +243,15 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { - return iptables.IPTables{}, nil +func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { + return stack.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool @@ -251,10 +276,10 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -266,7 +291,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -292,6 +317,21 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +// Abort implements TransportProtocol.Abort. +func (*fakeTransportProtocol) Abort() {} + +// Close implements tcpip.Endpoint.Close. +func (*fakeTransportProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeTransportProtocol) Wait() {} + +// Parse implements TransportProtocol.Parse. +func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { + _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) + return ok +} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } @@ -337,7 +377,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -346,7 +388,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -355,7 +399,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -408,7 +454,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -417,7 +465,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -426,7 +476,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -579,7 +631,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: req.ToVectorisedView(), + })) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -591,17 +645,16 @@ func TestTransportForwarding(t *testing.T) { t.Fatalf("Write failed: %v", err) } - var p channel.PacketInfo - select { - case p = <-ep2.C: - default: + p, ok := ep2.Read() + if !ok { t.Fatal("Response packet not forwarded") } - if dst := p.Header[0]; dst != 3 { + nh := stack.PayloadSince(p.Pkt.NetworkHeader()) + if dst := nh[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Header[1]; src != 1 { + if src := nh[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 03be7d3d4..07c85ce59 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -35,15 +35,17 @@ import ( "reflect" "strconv" "strings" - "sync" "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" ) +// Using header.IPv4AddressSize would cause an import cycle. +const ipv4AddressSize = 4 + // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // @@ -111,6 +113,71 @@ var ( ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ) +var messageToError map[string]*Error + +var populate sync.Once + +// StringToError converts an error message to the error. +func StringToError(s string) *Error { + populate.Do(func() { + var errors = []*Error{ + ErrUnknownProtocol, + ErrUnknownNICID, + ErrUnknownDevice, + ErrUnknownProtocolOption, + ErrDuplicateNICID, + ErrDuplicateAddress, + ErrNoRoute, + ErrBadLinkEndpoint, + ErrAlreadyBound, + ErrInvalidEndpointState, + ErrAlreadyConnecting, + ErrAlreadyConnected, + ErrNoPortAvailable, + ErrPortInUse, + ErrBadLocalAddress, + ErrClosedForSend, + ErrClosedForReceive, + ErrWouldBlock, + ErrConnectionRefused, + ErrTimeout, + ErrAborted, + ErrConnectStarted, + ErrDestinationRequired, + ErrNotSupported, + ErrQueueSizeNotSupported, + ErrNotConnected, + ErrConnectionReset, + ErrConnectionAborted, + ErrNoSuchFile, + ErrInvalidOptionValue, + ErrNoLinkAddress, + ErrBadAddress, + ErrNetworkUnreachable, + ErrMessageTooLong, + ErrNoBufferSpace, + ErrBroadcastDisabled, + ErrNotPermitted, + ErrAddressFamilyNotSupported, + } + + messageToError = make(map[string]*Error) + for _, e := range errors { + if messageToError[e.String()] != nil { + panic("tcpip errors with duplicated message: " + e.String()) + } + messageToError[e.String()] = e + } + }) + + e, ok := messageToError[s] + if !ok { + panic("unknown error message: " + s) + } + + return e +} + // Errors related to Subnet var ( errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") @@ -128,7 +195,7 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } -// A Clock provides the current time. +// A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. @@ -139,6 +206,31 @@ type Clock interface { // NowMonotonic returns a monotonic time value. NowMonotonic() int64 + + // AfterFunc waits for the duration to elapse and then calls f in its own + // goroutine. It returns a Timer that can be used to cancel the call using + // its Stop method. + AfterFunc(d time.Duration, f func()) Timer +} + +// Timer represents a single event. A Timer must be created with +// Clock.AfterFunc. +type Timer interface { + // Stop prevents the Timer from firing. It returns true if the call stops the + // timer, false if the timer has already expired or been stopped. + // + // If Stop returns false, then the timer has already expired and the function + // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop + // does not wait for f to complete before returning. If the caller needs to + // know whether f is completed, it must coordinate with f explicitly. + Stop() bool + + // Reset changes the timer to expire after duration d. + // + // Reset should be invoked only on stopped or expired timers. If the timer is + // known to have expired, Reset can be used directly. Otherwise, the caller + // must coordinate with the function f of Clock.AfterFunc(d, f). + Reset(d time.Duration) } // Address is a byte slice cast as a string that represents the address of a @@ -231,6 +323,36 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// IsBroadcast returns true if the address is considered a broadcast address. +func (s *Subnet) IsBroadcast(address Address) bool { + // Only IPv4 supports the notion of a broadcast address. + if len(address) != ipv4AddressSize { + return false + } + + // Normally, we would just compare address with the subnet's broadcast + // address but there is an exception where a simple comparison is not + // correct. This exception is for /31 and /32 IPv4 subnets where all + // addresses are considered valid host addresses. + // + // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that + // both addresses in a /31 subnet "MUST be interpreted as host addresses." + // + // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32 + // subnets. However, the same reasoning applies - if an exception is not + // made, then there do not exist any host addresses in a /32 subnet. RFC + // 4632 Section 3.1 also vaguely implies this interpretation by referring + // to addresses in /32 subnets as "host routes." + return s.Prefix() <= 30 && s.Broadcast() == address +} + +// Equal returns true if s equals o. +// +// Needed to use cmp.Equal on Subnet as its fields are unexported. +func (s Subnet) Equal(o Subnet) bool { + return s == o +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 @@ -245,6 +367,28 @@ const ( ShutdownWrite ) +// PacketType is used to indicate the destination of the packet. +type PacketType uint8 + +const ( + // PacketHost indicates a packet addressed to the local host. + PacketHost PacketType = iota + + // PacketOtherHost indicates an outgoing packet addressed to + // another host caught by a NIC in promiscuous mode. + PacketOtherHost + + // PacketOutgoing for a packet originating from the local host + // that is looped back to a packet socket. + PacketOutgoing + + // PacketBroadcast indicates a link layer broadcast packet. + PacketBroadcast + + // PacketMulticast indicates a link layer multicast packet. + PacketMulticast +) + // FullAddress represents a full transport node address, as required by the // Connect() and Bind() methods. // @@ -301,7 +445,7 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packed used to create + // Timestamp is the time (in ns) that the last packet used to create // the read data was received. Timestamp int64 @@ -310,6 +454,33 @@ type ControlMessages struct { // Inq is the number of bytes ready to be received. Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS uint8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo IPPacketInfo +} + +// PacketOwner is used to get UID and GID of the packet. +type PacketOwner interface { + // UID returns UID of the packet. + UID() uint32 + + // GID returns GID of the packet. + GID() uint32 } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -317,9 +488,15 @@ type ControlMessages struct { // networking stack. type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources - // associated with it. + // associated with it. Close initiates the teardown process, the + // Endpoint may not be fully closed when Close returns. Close() + // Abort initiates an expedited endpoint teardown. As compared to + // Close, Abort prioritizes closing the Endpoint quickly over cleanly. + // Abort is best effort; implementing Abort with Close is acceptable. + Abort() + // Read reads data from the endpoint and optionally returns the sender. // // This method does not block if there is no data pending. It will also @@ -404,17 +581,25 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptBool sets a socket option, for simple cases where a value + // has the bool type. + SetSockOptBool(opt SockOptBool, v bool) *Error + // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOpt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) *Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error + // GetSockOptBool gets a socket option for simple cases where a return + // value has the bool type. + GetSockOptBool(SockOptBool) (bool, *Error) + // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOpt) (int, *Error) + GetSockOptInt(SockOptInt) (int, *Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -427,14 +612,36 @@ type Endpoint interface { // NOTE: This method is a no-op for sockets other than TCP. ModerateRecvBuf(copied int) - // IPTables returns the iptables for this endpoint's stack. - IPTables() (iptables.IPTables, error) - // Info returns a copy to the transport endpoint info. Info() EndpointInfo // Stats returns a reference to the endpoint stats. Stats() EndpointStats + + // SetOwner sets the task owner to the endpoint owner. + SetOwner(owner PacketOwner) +} + +// LinkPacketInfo holds Link layer information for a received packet. +// +// +stateify savable +type LinkPacketInfo struct { + // Protocol is the NetworkProtocolNumber for the packet. + Protocol NetworkProtocolNumber + + // PktType is used to indicate the destination of the packet. + PktType PacketType +} + +// PacketEndpoint are additional methods that are only implemented by Packet +// endpoints. +type PacketEndpoint interface { + // ReadPacket reads a datagram/packet from the endpoint and optionally + // returns the sender and additional LinkPacketInfo. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) } // EndpointInfo is the interface implemented by each endpoint info struct. @@ -469,13 +676,117 @@ type WriteOptions struct { Atomic bool } -// SockOpt represents socket options which values have the int type. -type SockOpt int +// SockOptBool represents socket options which values have the bool type. +type SockOptBool int const ( + // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify + // whether datagram sockets are allowed to send packets to a broadcast + // address. + BroadcastOption SockOptBool = iota + + // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be held until segments are full by the TCP transport + // protocol. + CorkOption + + // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be sent out immediately by the transport protocol. For + // TCP, it determines if the Nagle algorithm is on or off. + DelayOption + + // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to + // specify whether TCP keepalive is enabled for this socket. + KeepaliveEnabledOption + + // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to + // specify whether multicast packets sent over a non-loopback interface + // will be looped back. + MulticastLoopOption + + // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify + // whether UDP checksum is disabled for this socket. + NoChecksumOption + + // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify + // whether SCM_CREDENTIALS socket control messages are enabled. + // + // Only supported on Unix sockets. + PasscredOption + + // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. + QuickAckOption + + // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to + // specify if the IPV6_TCLASS ancillary message is passed with incoming + // packets. + ReceiveTClassOption + + // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify + // if the TOS ancillary message is passed with incoming packets. + ReceiveTOSOption + + // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to + // specify if more inforamtion is provided with incoming packets such as + // interface index and address. + ReceiveIPPacketInfoOption + + // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to + // specify whether Bind() should allow reuse of local address. + ReuseAddressOption + + // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit + // multiple sockets to be bound to an identical socket address. + ReusePortOption + + // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify + // whether an IPv6 socket is to be restricted to sending and receiving + // IPv6 packets only. + V6OnlyOption + + // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw + // endpoint that all packets being written have an IP header and the + // endpoint should not attach an IP header. + IPHdrIncludedOption +) + +// SockOptInt represents socket options which values have the int type. +type SockOptInt int + +const ( + // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to + // specify the number of un-ACKed TCP keepalives that will be sent + // before the connection is closed. + KeepaliveCountOption SockOptInt = iota + + // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS + // for all subsequent outgoing IPv4 packets from the endpoint. + IPv4TOSOption + + // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to + // specify TOS for all subsequent outgoing IPv6 packets from the + // endpoint. + IPv6TrafficClassOption + + // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the + // current Maximum Segment Size(MSS) value as specified using the + // TCP_MAXSEG option. + MaxSegOption + + // MTUDiscoverOption is used to set/get the path MTU discovery setting. + // + // NOTE: Setting this option to any other value than PMTUDiscoveryDont + // is not supported and will fail as such, and getting this option will + // always return PMTUDiscoveryDont. + MTUDiscoverOption + + // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control + // the default TTL value for multicast messages. The default is 1. + MulticastTTLOption + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOpt = iota + ReceiveQueueSizeOption // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. @@ -489,47 +800,52 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. For TCP, - // it determines if the Nagle algorithm is on or off. - DelayOption + // TTLOption is used by SetSockOptInt/GetSockOptInt to control the + // default TTL/hop limit value for unicast messages. The default is + // protocol specific. + // + // A zero value indicates the default. + TTLOption - // TODO(b/137664753): convert all int socket options to be handled via - // GetSockOptInt. + // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify + // the number of SYN retransmits that TCP should send before aborting + // the attempt to connect. It cannot exceed 255. + // + // NOTE: This option is currently only stubbed out and is no-op. + TCPSynCountOption + + // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound + // the size of the advertised window to this value. + // + // NOTE: This option is currently only stubed out and is a no-op + TCPWindowClampOption ) -// ErrorOption is used in GetSockOpt to specify that the last error reported by -// the endpoint should be cleared and returned. -type ErrorOption struct{} +const ( + // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use + // per-route settings. + PMTUDiscoveryWant int = iota -// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 -// socket is to be restricted to sending and receiving IPv6 packets only. -type V6OnlyOption int + // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable + // path MTU discovery. + PMTUDiscoveryDont -// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be -// held until segments are full by the TCP transport protocol. -type CorkOption int + // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do + // path MTU discovery. + PMTUDiscoveryDo -// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() -// should allow reuse of local address. -type ReuseAddressOption int + // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF + // but ignore path MTU. + PMTUDiscoveryProbe +) -// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets -// to be bound to an identical socket address. -type ReusePortOption int +// ErrorOption is used in GetSockOpt to specify that the last error reported by +// the endpoint should be cleared and returned. +type ErrorOption struct{} // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. -type BindToDeviceOption string - -// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. -type QuickAckOption int - -// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether -// SCM_CREDENTIALS socket control messages are enabled. -// -// Only supported on Unix sockets. -type PasscredOption int +type BindToDeviceOption NICID // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // @@ -539,10 +855,6 @@ type TCPInfoOption struct { RTTVar time.Duration } -// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP keepalive is enabled for this socket. -type KeepaliveEnabledOption int - // KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a // connection must remain idle before the first TCP keepalive packet is sent. // Once this time is reached, KeepaliveIntervalOption is used instead. @@ -552,10 +864,10 @@ type KeepaliveIdleOption time.Duration // interval between sending TCP keepalive packets. type KeepaliveIntervalOption time.Duration -// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number -// of un-ACKed TCP keepalives that will be sent before the connection is -// closed. -type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user +// specified timeout for a given TCP connection. +// See: RFC5482 for details. +type TCPUserTimeoutOption time.Duration // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. @@ -565,23 +877,45 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive -// buffer moderation. +// ModerateReceiveBufferOption is used by buffer moderation. type ModerateReceiveBufferOption bool -// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current -// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. -type MaxSegOption int +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration -// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop -// limit value for unicast messages. The default is protocol specific. -// -// A zero value indicates the default. -type TTLOption uint8 +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + +// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a +// accept to return a completed connection only when there is data to be +// read. This usually means the listening socket will drop the final ACK +// for a handshake till the specified timeout until a segment with data arrives. +type TCPDeferAcceptOption time.Duration + +// TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding +// default MinRTO used by the Stack. +type TCPMinRTOOption time.Duration + +// TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding +// default MaxRTO used by the Stack. +type TCPMaxRTOOption time.Duration -// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default -// TTL value for multicast messages. The default is 1. -type MulticastTTLOption uint8 +// TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum number of retransmits after which we time out the connection. +type TCPMaxRetriesOption uint64 + +// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify +// the number of endpoints that can be in SYN-RCVD state before the stack +// switches to using SYN cookies. +type TCPSynRcvdCountThresholdOption uint64 + +// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide +// default for number of times SYN is retransmitted before aborting a connect. +type TCPSynRetriesOption uint8 // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. @@ -590,10 +924,6 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } -// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether -// multicast packets sent over a non-loopback interface will be looped back. -type MulticastLoopOption bool - // MembershipOption is used by SetSockOpt/GetSockOpt as an argument to // AddMembershipOption and RemoveMembershipOption. type MembershipOption struct { @@ -616,21 +946,51 @@ type RemoveMembershipOption MembershipOption // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int -// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether -// datagram sockets are allowed to send packets to a broadcast address. -type BroadcastOption int - // DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify // a default TTL. type DefaultTTLOption uint8 -// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv4 packets from the endpoint. -type IPv4TOSOption uint8 +// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached +// classic BPF filter on a given endpoint. +type SocketDetachFilterOption int + +// OriginalDestinationOption is used to get the original destination address +// and port of a redirected packet. +type OriginalDestinationOption FullAddress + +// TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to +// specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for +// new connections when it is safe from protocol viewpoint. +type TCPTimeWaitReuseOption uint8 + +const ( + // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot + // be reused for new connections. + TCPTimeWaitReuseDisabled TCPTimeWaitReuseOption = iota + + // TCPTimeWaitReuseGlobal indicates reuse of port bound by endponts in TIME-WAIT can + // be reused for new connections irrespective of the src/dest addresses. + TCPTimeWaitReuseGlobal + + // TCPTimeWaitReuseLoopbackOnly indicates reuse of port bound by endpoint in TIME-WAIT can + // only be reused if the connection was a connection over loopback. i.e src/dest adddresses + // are loopback addresses. + TCPTimeWaitReuseLoopbackOnly +) + +// IPPacketInfo is the message structure for IP_PKTINFO. +// +// +stateify savable +type IPPacketInfo struct { + // NIC is the ID of the NIC to be used. + NIC NICID + + // LocalAddr is the local address. + LocalAddr Address -// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv6 packets from the endpoint. -type IPv6TrafficClassOption uint8 + // DestinationAddr is the destination address found in the IP header. + DestinationAddr Address +} // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the @@ -852,9 +1212,13 @@ type IPStats struct { // link layer in nic.DeliverNetworkPacket. PacketsReceived *StatCounter - // InvalidAddressesReceived is the total number of IP packets received - // with an unknown or invalid destination address. - InvalidAddressesReceived *StatCounter + // InvalidDestinationAddressesReceived is the total number of IP packets + // received with an unknown or invalid destination address. + InvalidDestinationAddressesReceived *StatCounter + + // InvalidSourceAddressesReceived is the total number of IP packets received + // with a source address that should never have been received on the wire. + InvalidSourceAddressesReceived *StatCounter // PacketsDelivered is the total number of incoming IP packets that // are successfully delivered to the transport layer via HandlePacket. @@ -887,14 +1251,26 @@ type TCPStats struct { PassiveConnectionOpenings *StatCounter // CurrentEstablished is the number of TCP connections for which the - // current state is either ESTABLISHED or CLOSE-WAIT. + // current state is ESTABLISHED. CurrentEstablished *StatCounter + // CurrentConnected is the number of TCP connections that + // are in connected state. + CurrentConnected *StatCounter + // EstablishedResets is the number of times TCP connections have made // a direct transition to the CLOSED state from either the // ESTABLISHED state or the CLOSE-WAIT state. EstablishedResets *StatCounter + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of keep-alive time out. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter @@ -987,6 +1363,12 @@ type UDPStats struct { // PacketSendErrors is the number of datagrams failed to be sent. PacketSendErrors *StatCounter + + // ChecksumErrors is the number of datagrams dropped due to bad checksums. + ChecksumErrors *StatCounter + + // InvalidSourceAddress is the number of invalid sourced datagrams dropped. + InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. @@ -1030,6 +1412,9 @@ type ReceiveErrors struct { // ClosedReceiver is the number of received packets dropped because // of receiving endpoint state being closed. ClosedReceiver StatCounter + + // ChecksumErrors is the number of packets dropped due to bad checksums. + ChecksumErrors StatCounter } // SendErrors collects packet send errors within the transport layer for @@ -1055,6 +1440,10 @@ type ReadErrors struct { // InvalidEndpointState is the number of times we found the endpoint state // to be unexpected. InvalidEndpointState StatCounter + + // NotConnected is the number of times we tried to read but found that the + // endpoint was not connected. + NotConnected StatCounter } // WriteErrors collects packet write errors from an endpoint write call. @@ -1097,7 +1486,9 @@ type TransportEndpointStats struct { // marker interface. func (*TransportEndpointStats) IsEndpointStats() {} -func fillIn(v reflect.Value) { +// InitStatCounters initializes v's fields with nil StatCounter fields to new +// StatCounters. +func InitStatCounters(v reflect.Value) { for i := 0; i < v.NumField(); i++ { v := v.Field(i) if s, ok := v.Addr().Interface().(**StatCounter); ok { @@ -1105,14 +1496,14 @@ func fillIn(v reflect.Value) { *s = new(StatCounter) } } else { - fillIn(v) + InitStatCounters(v) } } } // FillIn returns a copy of s with nil fields initialized to new StatCounters. func (s Stats) FillIn() Stats { - fillIn(reflect.ValueOf(&s).Elem()) + InitStatCounters(reflect.ValueOf(&s).Elem()) return s } @@ -1322,8 +1713,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 8c0aacffa..1c8e2bc34 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -218,7 +218,7 @@ func TestAddressWithPrefixSubnet(t *testing.T) { gotSubnet := ap.Subnet() wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) if err != nil { - t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) continue } if gotSubnet != wantSubnet { diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD new file mode 100644 index 000000000..6d52af98a --- /dev/null +++ b/pkg/tcpip/tests/integration/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_test") + +package(licenses = ["notice"]) + +go_test( + name = "integration_test", + size = "small", + srcs = ["multicast_broadcast_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go new file mode 100644 index 000000000..9f0dd4d6d --- /dev/null +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -0,0 +1,438 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + defaultMTU = 1280 + ttl = 255 +) + +var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + ipv4Subnet = ipv4Addr.Subnet() + ipv4SubnetBcast = ipv4Subnet.Broadcast() + + ipv6Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("200a::1").To16()), + PrefixLen: 64, + } + ipv6Subnet = ipv6Addr.Subnet() + ipv6SubnetBcast = ipv6Subnet.Broadcast() + + // Remote addrs. + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + remoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16()) +) + +// TestPingMulticastBroadcast tests that responding to an Echo Request destined +// to a multicast or broadcast address uses a unicast source address for the +// reply. +func TestPingMulticastBroadcast(t *testing.T) { + const nicID = 1 + + rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + dstAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + dstAddr: ipv4Addr.Address, + }, + { + name: "IPv4 directed broadcast", + dstAddr: ipv4SubnetBcast, + }, + { + name: "IPv4 broadcast", + dstAddr: header.IPv4Broadcast, + }, + { + name: "IPv4 all-systems multicast", + dstAddr: header.IPv4AllSystems, + }, + { + name: "IPv6 unicast", + dstAddr: ipv6Addr.Address, + }, + { + name: "IPv6 all-nodes multicast", + dstAddr: header.IPv6AllNodesMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ipv4Proto := ipv4.NewProtocol() + ipv6Proto := ipv6.NewProtocol() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()}, + }) + // We only expect a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + var rxICMP func(*channel.Endpoint, tcpip.Address) + var expectedSrc tcpip.Address + var expectedDst tcpip.Address + var proto stack.NetworkProtocol + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + rxICMP = rxIPv4ICMP + expectedSrc = ipv4Addr.Address + expectedDst = remoteIPv4Addr + proto = ipv4Proto + case header.IPv6AddressSize: + rxICMP = rxIPv6ICMP + expectedSrc = ipv6Addr.Address + expectedDst = remoteIPv6Addr + proto = ipv6Proto + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + rxICMP(e, test.dstAddr) + pkt, ok := e.Read() + if !ok { + t.Fatal("expected ICMP response") + } + + if pkt.Route.LocalAddress != expectedSrc { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) + } + if pkt.Route.RemoteAddress != expectedDst { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) + } + + src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) + if src != expectedSrc { + t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) + } + if dst != expectedDst { + t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) + } + }) + } + +} + +// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some +// multicast or broadcast address. +func TestIncomingMulticastAndBroadcast(t *testing.T) { + const ( + nicID = 1 + remotePort = 5555 + localPort = 80 + ) + + data := []byte{1, 2, 3, 4} + + rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool + }{ + { + name: "IPv4 unicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + { + name: "IPv4 unicast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, + }, + { + name: "IPv4 unicast binding to wildcard", + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + + { + name: "IPv4 directed broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + { + name: "IPv4 directed broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, + }, + { + name: "IPv4 directed broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, + }, + { + name: "IPv4 broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, + }, + { + name: "IPv4 broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 all-systems multicast binding to all-systems multicast", + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to wildcard", + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, + }, + + // IPv6 has no notion of a broadcast. + { + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + expectRx: true, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + expectRx: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + var netproto tcpip.NetworkProtocolNumber + var rxUDP func(*channel.Endpoint, tcpip.Address) + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + netproto = header.IPv4ProtocolNumber + rxUDP = rxIPv4UDP + case header.IPv6AddressSize: + netproto = header.IPv6ProtocolNumber + rxUDP = rxIPv6UDP + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + } + + rxUDP(e, test.dstAddr) + if gotPayload, _, err := ep.Read(nil); test.expectRx { + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index a52262e87..f32d58091 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,18 +13,20 @@ // limitations under the License. // +build go1.9 -// +build !go1.14 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. package tcpip import ( - _ "time" // Used with go:linkname. + "time" // Used with go:linkname. _ "unsafe" // Required for go:linkname. ) // StdClock implements Clock with the time package. +// +// +stateify savable type StdClock struct{} var _ Clock = (*StdClock)(nil) @@ -43,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 { _, _, mono := now() return mono } + +// AfterFunc implements Clock.AfterFunc. +func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go new file mode 100644 index 000000000..f1dd7c310 --- /dev/null +++ b/pkg/tcpip/timer.go @@ -0,0 +1,206 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +// jobInstance is a specific instance of Job. +// +// Different instances are created each time Job is scheduled so each timer has +// its own earlyReturn signal. This is to address a bug when a Job is stopped +// and reset in quick succession resulting in a timer instance's earlyReturn +// signal being affected or seen by another timer instance. +// +// Consider the following sceneario where timer instances share a common +// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a +// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second +// (B), third (C), and fourth (D) instance of the timer firing, respectively): +// T1: Obtain L +// T1: Create a new Job w/ lock L (create instance A) +// T2: instance A fires, blocked trying to obtain L. +// T1: Attempt to stop instance A (set earlyReturn = true) +// T1: Schedule timer (create instance B) +// T3: instance B fires, blocked trying to obtain L. +// T1: Attempt to stop instance B (set earlyReturn = true) +// T1: Schedule timer (create instance C) +// T4: instance C fires, blocked trying to obtain L. +// T1: Attempt to stop instance C (set earlyReturn = true) +// T1: Schedule timer (create instance D) +// T5: instance D fires, blocked trying to obtain L. +// T1: Release L +// +// Now that T1 has released L, any of the 4 timer instances can take L and +// check earlyReturn. If the timers simply check earlyReturn and then do +// nothing further, then instance D will never early return even though it was +// not requested to stop. If the timers reset earlyReturn before early +// returning, then all but one of the timers will do work when only one was +// expected to. If Job resets earlyReturn when resetting, then all the timers +// will fire (again, when only one was expected to). +// +// To address the above concerns the simplest solution was to give each timer +// its own earlyReturn signal. +type jobInstance struct { + timer Timer + + // Used to inform the timer to early return when it gets stopped while the + // lock the timer tries to obtain when fired is held (T1 is a goroutine that + // tries to cancel the timer and T2 is the goroutine that handles the timer + // firing): + // T1: Obtain the lock, then call Cancel() + // T2: timer fires, and gets blocked on obtaining the lock + // T1: Releases lock + // T2: Obtains lock does unintended work + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using earlyReturn to return early so that once T2 obtains + // the lock, it will see that it is set to true and do nothing further. + earlyReturn *bool +} + +// stop stops the job instance j from firing if it hasn't fired already. If it +// has fired and is blocked at obtaining the lock, earlyReturn will be set to +// true so that it will early return when it obtains the lock. +func (j *jobInstance) stop() { + if j.timer != nil { + j.timer.Stop() + *j.earlyReturn = true + } +} + +// Job represents some work that can be scheduled for execution. The work can +// be safely cancelled when it fires at the same time some "related work" is +// being done. +// +// The term "related work" is defined as some work that needs to be done while +// holding some lock that the timer must also hold while doing some work. +// +// Note, it is not safe to copy a Job as its timer instance creates +// a closure over the address of the Job. +type Job struct { + _ sync.NoCopy + + // The clock used to schedule the backing timer + clock Clock + + // The active instance of a cancellable timer. + instance jobInstance + + // locker is the lock taken by the timer immediately after it fires and must + // be held when attempting to stop the timer. + // + // Must never change after being assigned. + locker sync.Locker + + // fn is the function that will be called when a timer fires and has not been + // signaled to early return. + // + // fn MUST NOT attempt to lock locker. + // + // Must never change after being assigned. + fn func() +} + +// Cancel prevents the Job from executing if it has not executed already. +// +// Cancel requires appropriate locking to be in place for any resources managed +// by the Job. If the Job is blocked on obtaining the lock when Cancel is +// called, it will early return. +// +// Note, t will be modified. +// +// j.locker MUST be locked. +func (j *Job) Cancel() { + j.instance.stop() + + // Nothing to do with the stopped instance anymore. + j.instance = jobInstance{} +} + +// Schedule schedules the Job for execution after duration d. This can be +// called on cancelled or completed Jobs to schedule them again. +// +// Schedule should be invoked only on unscheduled, cancelled, or completed +// Jobs. To be safe, callers should always call Cancel before calling Schedule. +// +// Note, j will be modified. +func (j *Job) Schedule(d time.Duration) { + // Create a new instance. + earlyReturn := false + + // Capture the locker so that updating the timer does not cause a data race + // when a timer fires and tries to obtain the lock (read the timer's locker). + locker := j.locker + j.instance = jobInstance{ + timer: j.clock.AfterFunc(d, func() { + locker.Lock() + defer locker.Unlock() + + if earlyReturn { + // If we reach this point, it means that the timer fired while another + // goroutine called Cancel while it had the lock. Simply return here + // and do nothing further. + earlyReturn = false + return + } + + j.fn() + }), + earlyReturn: &earlyReturn, + } +} + +// NewJob returns a new Job that can be used to schedule f to run in its own +// gorountine. l will be locked before calling f then unlocked after f returns. +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// message = "bar" +// mu.Unlock() +// +// // Output: bar +// +// f MUST NOT attempt to lock l. +// +// l MUST be locked prior to calling the returned job's Cancel(). +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// job.Cancel() +// mu.Unlock() +func NewJob(c Clock, l sync.Locker, f func()) *Job { + return &Job{ + clock: c, + locker: l, + fn: f, + } +} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go new file mode 100644 index 000000000..a82384c49 --- /dev/null +++ b/pkg/tcpip/timer_test.go @@ -0,0 +1,268 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip_test + +import ( + "sync" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + shortDuration = 1 * time.Nanosecond + middleDuration = 100 * time.Millisecond + longDuration = 1 * time.Second +) + +func TestJobReschedule(t *testing.T) { + var clock tcpip.StdClock + var wg sync.WaitGroup + var lock sync.Mutex + + for i := 0; i < 2; i++ { + wg.Add(1) + + go func() { + lock.Lock() + // Assigning a new timer value updates the timer's locker and function. + // This test makes sure there is no data race when reassigning a timer + // that has an active timer (even if it has been stopped as a stopped + // timer may be blocked on a lock before it can check if it has been + // stopped while another goroutine holds the same lock). + job := tcpip.NewJob(&clock, &lock, func() { + wg.Done() + }) + job.Schedule(shortDuration) + lock.Unlock() + }() + } + wg.Wait() +} + +func TestJobExecution(t *testing.T) { + t.Parallel() + + var clock tcpip.StdClock + var lock sync.Mutex + ch := make(chan struct{}) + + job := tcpip.NewJob(&clock, &lock, func() { + ch <- struct{}{} + }) + job.Schedule(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerResetFromLongDuration(t *testing.T) { + t.Parallel() + + var clock tcpip.StdClock + var lock sync.Mutex + ch := make(chan struct{}) + + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(middleDuration) + + lock.Lock() + job.Cancel() + lock.Unlock() + + job.Schedule(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestJobRescheduleFromShortDuration(t *testing.T) { + t.Parallel() + + var clock tcpip.StdClock + var lock sync.Mutex + ch := make(chan struct{}) + + lock.Lock() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() + lock.Unlock() + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } + + job.Schedule(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestJobImmediatelyCancel(t *testing.T) { + t.Parallel() + + var clock tcpip.StdClock + var lock sync.Mutex + ch := make(chan struct{}) + + for i := 0; i < 1000; i++ { + lock.Lock() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() + lock.Unlock() + } + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } +} + +func TestJobCancelledRescheduleWithoutLock(t *testing.T) { + t.Parallel() + + var clock tcpip.StdClock + var lock sync.Mutex + ch := make(chan struct{}) + + lock.Lock() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() + lock.Unlock() + + for i := 0; i < 10; i++ { + job.Schedule(middleDuration) + + lock.Lock() + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration * 2) + job.Cancel() + lock.Unlock() + } + + // Wait for double the duration so timers that weren't correctly stopped can + // fire. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration * 2): + } +} + +func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { + t.Parallel() + + var clock tcpip.StdClock + var lock sync.Mutex + ch := make(chan struct{}) + + lock.Lock() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + for i := 0; i < 10; i++ { + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration) + job.Cancel() + job.Schedule(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestManyJobReschedulesUnderLock(t *testing.T) { + t.Parallel() + + var clock tcpip.StdClock + var lock sync.Mutex + ch := make(chan struct{}) + + lock.Lock() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + for i := 0; i < 10; i++ { + job.Cancel() + job.Schedule(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9254c3dea..7e5c79776 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,26 +23,18 @@ go_library( "icmp_packet_list.go", "protocol.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "icmp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 043467519..bd6f49eb8 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,12 +15,11 @@ package icmp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -31,9 +30,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -58,6 +54,7 @@ type endpoint struct { // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -77,6 +74,9 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -90,9 +90,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -100,7 +111,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -126,9 +137,8 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { - return e.stack.IPTables(), nil +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner } // Read reads data from the endpoint. This method does not block if @@ -274,24 +284,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } - toCopy := *to - to = &toCopy - netProto, err := e.checkV4Mapped(to, true) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + // Find the endpoint. + r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -316,7 +324,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl) + err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) case header.IPv6ProtocolNumber: err = send6(route, e.ID.LocalPort, v, e.ttl) @@ -336,23 +344,43 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(o) - e.mu.Unlock() + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil } + return nil +} +// SetSockOptBool sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return nil } // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + } return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -375,39 +403,39 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + case tcpip.TTLOption: + e.rcvMu.Lock() + v := int(e.ttl) + e.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.TTLOption: - e.rcvMu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.rcvMu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + }) + pkt.Owner = owner - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -422,10 +450,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + pkt.Data = data.ToVectorisedView() + if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -433,9 +463,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + }) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -447,26 +479,22 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err dataVV := data.ToVectorisedView() icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) + pkt.Data = dataVV if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, dataVV, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - return 0, tcpip.ErrNoRoute - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) + if err != nil { + return tcpip.FullAddress{}, 0, err } - - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -479,31 +507,32 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC localPort := uint16(0) switch e.state { + case stateInitial: case stateBound, stateConnected: localPort = e.ID.LocalPort if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -520,14 +549,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id e.route = r.Clone() - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.state = stateConnected @@ -578,18 +607,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -610,7 +639,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -714,19 +743,23 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(vv.First()) - if h.Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(vv.First()) - if h.Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -753,19 +786,21 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &icmpPacket{ + packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader().View().ToVectorisedView() + packet.data.Append(pkt.Data) - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() @@ -776,7 +811,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't @@ -798,3 +833,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index bfb16f7c3..74ef6541e 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,20 +104,36 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. + // + // Right now, the Parse() method is tied to enabled protocols passed into + // stack.New. This works for UDP and TCP, but we handle ICMP traffic even + // when netstack users don't pass ICMP as a supported protocol. + return false +} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 8ea2e6ee5..b989b1209 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,25 +22,16 @@ go_library( "endpoint_state.go", "packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 73cdaa265..1b03ad6bb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,12 +25,12 @@ package packet import ( - "sync" + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -41,14 +41,13 @@ type packet struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -77,10 +76,17 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID + + // lastErrorMu protects lastError. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` } // NewEndpoint returns a new packet endpoint. @@ -97,12 +103,28 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb sndBufSize: 32 * 1024, } + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.rcvBufSizeMax = rs.Default + } + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { return nil, err } return ep, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (ep *endpoint) Abort() { + ep.Close() +} + // Close implements tcpip.Endpoint.Close. func (ep *endpoint) Close() { ep.mu.Lock() @@ -125,19 +147,15 @@ func (ep *endpoint) Close() { } ep.closed = true + ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (iptables.IPTables, error) { - return ep.stack.IPTables(), nil -} - -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. +func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -162,11 +180,20 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - // TODO(b/129292371): Implement. + // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } @@ -216,7 +243,27 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." // - packet(7). - return tcpip.ErrNotSupported + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.bound && ep.boundNIC == addr.NIC { + // If the NIC being bound is the same then just return success. + return nil + } + + // Unregister endpoint with all the nics. + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.bound = false + + // Bind endpoint to receive packets from specific interface. + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + return err + } + + ep.bound = true + ep.boundNIC = addr.NIC + + return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. @@ -251,26 +298,119 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrNotSupported + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := ep.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + ep.mu.Lock() + ep.sndBufSizeMax = v + ep.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := ep.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + ep.rcvMu.Lock() + ep.rcvBufSizeMax = v + ep.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (ep *endpoint) takeLastError() *tcpip.Error { + ep.lastErrorMu.Lock() + defer ep.lastErrorMu.Unlock() + + err := ep.lastError + ep.lastError = nil + return err } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return ep.takeLastError() + } return tcpip.ErrNotSupported } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrNotSupported +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSizeMax + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, ethHeader buffer.View) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -292,45 +432,73 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(b/129292371): Return network protocol. - if len(ethHeader) > 0 { + // TODO(gvisor.dev/issue/173): Return network protocol. + if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. - hdr := header.Ethernet(ethHeader) + hdr := header.Ethernet(pkt.LinkHeader().View()) packet.senderAddr = tcpip.FullAddress{ - NIC: nicid, + NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ - NIC: nicid, + NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = vv.Clone(packet.views[:]) + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header. + var combinedVV buffer.VectorisedView + if v := pkt.NetworkHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + if v := pkt.TransportHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. - if len(ethHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + var linkHeader buffer.View + if pkt.PktType != tcpip.PacketOutgoing { + if pkt.LinkHeader().View().IsEmpty() { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - ethHeader = buffer.View(fakeHeader) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } else { + packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } - combinedVV := buffer.View(ethHeader).ToVectorisedView() - combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) } - packet.timestampNS = ep.stack.NowNanoseconds() + packet.timestampNS = ep.stack.Clock().NowNanoseconds() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() @@ -361,3 +529,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo { func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } + +func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..e2fa96d17 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() { panic(*err) } } + +// saveLastError is invoked by stateify. +func (ep *endpoint) saveLastError() string { + if ep.lastError == nil { + return "" + } + + return ep.lastError.String() +} + +// loadLastError is invoked by stateify. +func (ep *endpoint) loadLastError(s string) { + if s == "" { + return + } + + ep.lastError = tcpip.StringToError(s) +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 4af49218c..2eab09088 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,26 +23,17 @@ go_library( "protocol.go", "raw_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/packet", "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "raw_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 308f10d24..edc2b5b61 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,12 +26,12 @@ package raw import ( - "sync" + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -42,10 +42,6 @@ type rawPacket struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -67,25 +63,30 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` rcvList rawPacketList - rcvBufSizeMax int `state:".(int)"` rcvBufSize int + rcvBufSizeMax int `state:".(int)"` rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - connected bool - bound bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + connected bool + bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // NewEndpoint returns a raw endpoint for the given protocols. @@ -94,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if netProto != header.IPv4ProtocolNumber { + if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } @@ -106,8 +107,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, + sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, + } + + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + e.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + e.rcvBufSizeMax = rs.Default } // Unassociated endpoints are write-only and users call Write() with IP @@ -126,6 +139,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (e *endpoint) Close() { e.mu.Lock() @@ -160,17 +178,12 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { - return e.stack.IPTables(), nil +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner } // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -200,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // We can create, but not write to, unassociated IPv6 endpoints. + if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + return 0, nil, tcpip.ErrInvalidOptionValue + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -243,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -304,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrNoRoute } - // We don't support IPv6 yet, so this has to be an IPv4 address. - if len(opts.To.Addr) != header.IPv4AddressSize { - e.mu.RUnlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) @@ -339,21 +351,26 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch e.NetProto { - case header.IPv4ProtocolNumber: - if !e.associated { - if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { - return 0, nil, err - } - break + if e.hdrIncluded { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + if err := route.WriteHeaderIncludedPacket(pkt); err != nil { + return 0, nil, err } - hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + } else { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + pkt.Owner = e.owner + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: e.TransProto, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, pkt); err != nil { return 0, nil, err } - - default: - return 0, nil, tcpip.ErrUnknownProtocol } return int64(len(payloadBytes)), nil, nil @@ -378,11 +395,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // We don't support IPv6 yet. - if len(addr.Addr) != header.IPv4AddressSize { - return tcpip.ErrInvalidEndpointState - } - nic := addr.NIC if e.bound { if e.BindNICID == 0 { @@ -448,14 +460,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // Callers must provide an IPv4 address or no network address (for - // binding to a NIC, but not an address). - if len(addr.Addr) != 0 && len(addr.Addr) != 4 { - return tcpip.ErrInvalidEndpointState - } - // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } @@ -505,16 +511,101 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + e.mu.Lock() + e.sndBufSizeMax = v + e.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + e.rcvMu.Lock() + e.rcvBufSizeMax = v + e.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -528,7 +619,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { case tcpip.SendBufferSizeOption: e.mu.Lock() - v := e.sndBufSize + v := e.sndBufSizeMax e.mu.Unlock() return v, nil @@ -538,32 +629,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - } - - return -1, tcpip.ErrUnknownProtocolOption -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: - return tcpip.ErrUnknownProtocolOption + return -1, tcpip.ErrUnknownProtocolOption } } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() @@ -600,21 +683,33 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &rawPacket{ + packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: route.NICID(), Addr: route.RemoteAddress, }, } - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - pkt.data = combinedVV.Clone(pkt.views[:]) - pkt.timestampNS = e.stack.NowNanoseconds() - - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() - + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + var combinedVV buffer.VectorisedView + if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() + } + combinedVV.Append(pkt.Data) + packet.data = combinedVV + packet.timestampNS = e.stack.Clock().NowNanoseconds() + + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. @@ -641,3 +736,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f1dbc6f91..234fb95ce 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -16,18 +15,35 @@ go_template_instance( }, ) +go_template_instance( + name = "tcp_endpoint_list", + out = "tcp_endpoint_list.go", + package = "tcp", + prefix = "endpoint", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*endpoint", + "Linker": "*endpoint", + }, +) + go_library( name = "tcp", srcs = [ "accept.go", "connect.go", + "connect_unsafe.go", "cubic.go", "cubic_state.go", + "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", "protocol.go", + "rack.go", + "rack_state.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", @@ -35,55 +51,49 @@ go_library( "segment_heap.go", "segment_queue.go", "segment_state.go", + "segment_unsafe.go", "snd.go", "snd_state.go", + "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", - "//pkg/tmutex", "//pkg/waiter", "@com_github_google_btree//:go_default_library", ], ) -filegroup( - name = "autogen", - srcs = [ - "tcp_segment_list.go", - ], - visibility = ["//:sandbox"], -) - go_test( - name = "tcp_test", - size = "small", + name = "tcp_x_test", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", + "tcp_rack_test.go", "tcp_sack_test.go", "tcp_test.go", "tcp_timestamp_test.go", ], - # FIXME(b/68809571) - tags = ["flaky"], + shard_count = 10, deps = [ ":tcp", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", @@ -96,6 +106,25 @@ go_test( "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", + "//pkg/test/testutil", "//pkg/waiter", ], ) + +go_test( + name = "rcv_test", + size = "small", + srcs = ["rcv_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + ], +) + +go_test( + name = "tcp_test", + size = "small", + srcs = ["timer_test.go"], + library = ":tcp", + deps = ["//pkg/sleep"], +) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 65c346046..b706438bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -17,13 +17,14 @@ package tcp import ( "crypto/sha1" "encoding/binary" + "fmt" "hash" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -48,17 +49,14 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. maxTSDiff = 2 -) -var ( - // SynRcvdCountThreshold is the global maximum number of connections - // that are allowed to be in SYN-RCVD state before TCP starts using SYN - // cookies to accept connections. - // - // It is an exported variable only for testing, and should not otherwise - // be used by importers of this package. + // SynRcvdCountThreshold is the default global maximum number of + // connections that are allowed to be in SYN-RCVD state before TCP + // starts using SYN cookies to accept connections. SynRcvdCountThreshold uint64 = 1000 +) +var ( // mssTable is a slice containing the possible MSS values that we // encode in the SYN cookie with two bits. mssTable = []uint16{536, 1300, 1440, 1460} @@ -73,29 +71,42 @@ func encodeMSS(mss uint16) uint32 { return 0 } -// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is -// protected by a mutex so that we can increment only when it's guaranteed not -// to go above a threshold. -var synRcvdCount struct { - sync.Mutex - value uint64 - pending sync.WaitGroup -} - // listenContext is used by a listening endpoint to store state used while // listening for connections. This struct is allocated by the listen goroutine // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. type listenContext struct { - stack *stack.Stack - rcvWnd seqnum.Size - nonce [2][sha1.BlockSize]byte + stack *stack.Stack + + // synRcvdCount is a reference to the stack level synRcvdCount. + synRcvdCount *synRcvdCounter + + // rcvWnd is the receive window that is sent by this listening context + // in the initial SYN-ACK. + rcvWnd seqnum.Size + + // nonce are random bytes that are initialized once when the context + // is created and used to seed the hash function when generating + // the SYN cookie. + nonce [2][sha1.BlockSize]byte + + // listenEP is a reference to the listening endpoint associated with + // this context. Can be nil if the context is created by the forwarder. listenEP *endpoint + // hasherMu protects hasher. hasherMu sync.Mutex - hasher hash.Hash - v6only bool + // hasher is the hash function used to generate a SYN cookie. + hasher hash.Hash + + // v6Only is true if listenEP is a dual stack socket and has the + // IPV6_V6ONLY option set. + v6Only bool + + // netProto indicates the network protocol(IPv4/v6) for the listening + // endpoint. netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed // by the listening endpoint's worker goroutine. // @@ -114,55 +125,22 @@ func timeStamp() uint32 { return uint32(time.Now().Unix()>>6) & tsMask } -// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD -// state. It succeeds if the increment doesn't make the count go beyond the -// threshold, and fails otherwise. -func incSynRcvdCount() bool { - synRcvdCount.Lock() - - if synRcvdCount.value >= SynRcvdCountThreshold { - synRcvdCount.Unlock() - return false - } - - synRcvdCount.pending.Add(1) - synRcvdCount.value++ - - synRcvdCount.Unlock() - return true -} - -// decSynRcvdCount atomically decrements the global number of endpoints in -// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount -// succeeded. -func decSynRcvdCount() { - synRcvdCount.Lock() - - synRcvdCount.value-- - synRcvdCount.pending.Done() - synRcvdCount.Unlock() -} - -// synCookiesInUse() returns true if the synRcvdCount is greater than -// SynRcvdCountThreshold. -func synCookiesInUse() bool { - synRcvdCount.Lock() - v := synRcvdCount.value - synRcvdCount.Unlock() - return v >= SynRcvdCountThreshold -} - // newListenContext creates a new listen context. -func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ stack: stk, rcvWnd: rcvWnd, hasher: sha1.New(), - v6only: v6only, + v6Only: v6Only, netProto: netProto, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } + p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) + if !ok { + panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) + } + l.synRcvdCount = p.SynRcvdCounter() rand.Read(l.nonce[0][:]) rand.Read(l.nonce[1][:]) @@ -221,85 +199,119 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { // Create a new endpoint. netProto := l.netProto if netProto == 0 { netProto = s.route.NetProto } - n := newEndpoint(l.stack, netProto, nil) - n.v6only = l.v6only + n := newEndpoint(l.stack, netProto, queue) + n.v6only = l.v6Only n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) - n.amss = mssForRoute(&n.route) + n.amss = calculateAdvertisedMSS(n.userMSS, n.route) + n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) n.initGSO() - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { - n.Close() - return nil, err - } - - n.isRegistered = true - - // Create sender and receiver. - // - // The receiver at least temporarily has a zero receive window scale, - // but the caller may change it (before starting the protocol loop). - n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) - n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize())) // Bootstrap the auto tuning algorithm. Starting at zero will result in // a large step function on the first window adjustment causing the // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - return n, nil + return n } -// createEndpoint creates a new endpoint in connected state and then performs -// the TCP 3-way handshake. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +// createEndpointAndPerformHandshake creates a new endpoint in connected state +// and then performs the TCP 3-way handshake. +// +// The new endpoint is returned with e.mu held. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) - if err != nil { - return nil, err - } + isn := generateSecureISN(s.id, l.stack.Seed()) + ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + ep.mu.Lock() + ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. + deferAccept := time.Duration(0) if l.listenEP != nil { l.listenEP.mu.Lock() - if l.listenEP.state != StateListen { + if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() + // Ensure we release any registrations done by the newly + // created endpoint. + ep.mu.Unlock() + ep.Close() + return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + l.listenEP.propagateInheritableOptionsLocked(ep) + + if !ep.reserveTupleLocked() { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + + return nil, tcpip.ErrConnectionAborted + } + + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } - // Perform the 3-way handshake. - h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) + // Register new endpoint so that packets are routed to it. + if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } - h.resetToSynRcvd(cookie, irs, opts) + ep.drainClosingSegmentQueue() + + return nil, err + } + + ep.isRegistered = true + + // Perform the 3-way handshake. + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { + ep.mu.Unlock() ep.Close() + ep.notifyAborted() + if l.listenEP != nil { l.removePendingEndpoint(ep) } + + ep.drainClosingSegmentQueue() + return nil, err } - ep.mu.Lock() - ep.stack.Stats().TCP.CurrentEstablished.Increment() - ep.state = StateEstablished - ep.mu.Unlock() + ep.isConnectNotified = true // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -333,23 +345,78 @@ func (l *listenContext) closeAllPendingEndpoints() { } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state, the new endpoint is closed -// instead. +// endpoint has transitioned out of the listen state (acceptedChan is nil), +// the new endpoint is closed instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.state e.pendingAccepted.Add(1) - defer e.pendingAccepted.Done() - acceptedChan := e.acceptedChan e.mu.Unlock() - if state == StateListen { - acceptedChan <- n - e.waiterQueue.Notify(waiter.EventIn) - } else { - n.Close() + defer e.pendingAccepted.Done() + + e.acceptMu.Lock() + for { + if e.acceptedChan == nil { + e.acceptMu.Unlock() + n.notifyProtocolGoroutine(notifyReset) + return + } + select { + case e.acceptedChan <- n: + e.acceptMu.Unlock() + e.waiterQueue.Notify(waiter.EventIn) + return + default: + e.acceptCond.Wait() + } } } +// propagateInheritableOptionsLocked propagates any options set on the listening +// endpoint to the newly created endpoint. +// +// Precondition: e.mu and n.mu must be held. +func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { + n.userTimeout = e.userTimeout + n.portFlags = e.portFlags + n.boundBindToDevice = e.boundBindToDevice + n.boundPortFlags = e.boundPortFlags + n.userMSS = e.userMSS +} + +// reserveTupleLocked reserves an accepted endpoint's tuple. +// +// Preconditions: +// * propagateInheritableOptionsLocked has been called. +// * e.mu is held. +func (e *endpoint) reserveTupleLocked() bool { + dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + if !e.stack.ReserveTuple( + e.effectiveNetProtos, + ProtocolNumber, + e.ID.LocalAddress, + e.ID.LocalPort, + e.boundPortFlags, + e.boundBindToDevice, + dest, + ) { + return false + } + + e.isPortReserved = true + e.boundDest = dest + return true +} + +// notifyAborted wakes up any waiters on registered, but not accepted +// endpoints. +// +// This is strictly not required normally as a socket that was never accepted +// can't really have any registered waiters except when stack.Wait() is called +// which waits for all registered endpoints to stop and expects an EventHUp. +func (e *endpoint) notifyAborted() { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. @@ -357,53 +424,68 @@ func (e *endpoint) deliverAccepted(n *endpoint) { // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer decSynRcvdCount() - defer e.decSynRcvdCount() + defer ctx.synRcvdCount.dec() + defer func() { + e.mu.Lock() + e.decSynRcvdCount() + e.mu.Unlock() + }() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts) + + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) + n.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } func (e *endpoint) incSynRcvdCount() bool { - e.mu.Lock() - if e.synRcvdCount >= cap(e.acceptedChan) { - e.mu.Unlock() - return false + e.acceptMu.Lock() + canInc := e.synRcvdCount < cap(e.acceptedChan) + e.acceptMu.Unlock() + if canInc { + e.synRcvdCount++ } - e.synRcvdCount++ - e.mu.Unlock() - return true + return canInc } func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() e.synRcvdCount-- - e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { - e.mu.Lock() - if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { - e.mu.Unlock() - return true - } - e.mu.Unlock() - return false + e.acceptMu.Lock() + full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + e.acceptMu.Unlock() + return full } // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - switch s.flags { - case header.TCPFlagSyn: + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() + if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + // If the endpoint is shutdown, reply with reset. + // + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + replyWithReset(s, e.sendTOS, e.ttl) + return + } + + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) - if incSynRcvdCount() { + if ctx.synRcvdCount.inc() { // Only handle the syn if the following conditions hold // - accept queue is not full. // - number of connections in synRcvd state is less than the @@ -413,7 +495,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. return } - decSynRcvdCount() + ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -430,23 +512,33 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) // Send SYN without window scaling because we currently - // dont't encode this information in the cookie. + // don't encode this information in the cookie. // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(timeStampOffset()), + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: calculateAdvertisedMSS(e.userMSS, s.route), } - e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + }, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -459,7 +551,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } - if !synCookiesInUse() { + if !ctx.synRcvdCount.synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. @@ -467,10 +567,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s) + replyWithReset(s, e.sendTOS, e.ttl) return } + iss := s.ackNumber - 1 + irs := s.sequenceNumber - 1 + // Since SYN cookies are in use this is potentially an ACK to a // SYN-ACK we sent but don't have a half open connection state // as cookies are being used to protect against a potential SYN @@ -481,7 +584,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // when under a potential syn flood attack. // // Validate the cookie. - data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) + data, ok := ctx.isCookieValid(s.id, iss, irs) if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -506,13 +609,35 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) - if err != nil { + n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + + n.mu.Lock() + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + + if !n.reserveTupleLocked() { + n.mu.Unlock() + n.Close() + + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return + } + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } + n.isRegistered = true + // clear the tsOffset for the newly created // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was @@ -520,8 +645,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. - n.stack.Stats().TCP.CurrentEstablished.Increment() - n.state = StateEstablished + n.isConnectNotified = true + n.transitionToStateEstablishedLocked(&handshake{ + ep: n, + iss: iss, + ackNum: irs + 1, + rcvWnd: seqnum.Size(n.initialReceiveWindow()), + sndWnd: s.window, + rcvWndScale: e.rcvWndScaleForHandshake(), + sndWndScale: rcvdSynOptions.WS, + mss: rcvdSynOptions.MSS, + }) // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -532,6 +666,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + n.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } @@ -540,16 +678,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() - v6only := e.v6only - e.mu.Unlock() - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) + v6Only := e.v6only + ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections // to the endpoint. - e.mu.Lock() - e.state = StateClose + e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() @@ -562,15 +698,20 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { } e.mu.Unlock() + e.drainClosingSegmentQueue() + // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { - switch index, _ := s.Fetch(true); index { + e.mu.Unlock() + index, _ := s.Fetch(true) + e.mu.Lock() + switch index { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { @@ -583,7 +724,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.decRef() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } case wakerForNewSegment: diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 790e89cc3..290172ac9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,13 +15,15 @@ package tcp import ( - "sync" + "encoding/binary" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -59,6 +61,9 @@ const ( ) // handshake holds the state used during a TCP 3-way handshake. +// +// NOTE: handshake.ep.mu is held during handshake processing. It is released if +// we are going to block and reacquired when we start processing an event. type handshake struct { ep *endpoint state handshakeState @@ -84,32 +89,38 @@ type handshake struct { // rcvWndScale is the receive window scale, as defined in RFC 1323. rcvWndScale int -} -func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { - rcvWndScale := ep.rcvWndScaleForHandshake() + // startTime is the time at which the first SYN/SYN-ACK was sent. + startTime time.Time - // Round-down the rcvWnd to a multiple of wndScale. This ensures that the - // window offered in SYN won't be reduced due to the loss of precision if - // window scaling is enabled after the handshake. - rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale) + // deferAccept if non-zero will drop the final ACK for a passive + // handshake till an ACK segment with data is received or the timeout is + // hit. + deferAccept time.Duration - // Ensure we can always accept at least 1 byte if the scale specified - // was too high for the provided rcvWnd. - if rcvWnd == 0 { - rcvWnd = 1 - } + // acked is true if the the final ACK for a 3-way handshake has + // been received. This is required to stop retransmitting the + // original SYN-ACK when deferAccept is enabled. + acked bool +} +func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { h := handshake{ ep: ep, active: true, rcvWnd: rcvWnd, - rcvWndScale: int(rcvWndScale), + rcvWndScale: ep.rcvWndScaleForHandshake(), } h.resetState() return h } +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) + h.resetToSynRcvd(isn, irs, opts, deferAccept) + return h +} + // FindWndScale determines the window scale to use for the given maximum window // size. func FindWndScale(wnd seqnum.Size) int { @@ -139,7 +150,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -154,7 +190,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -162,9 +198,8 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.ackNum = irs + 1 h.mss = opts.MSS h.sndWndScale = opts.WS - h.ep.mu.Lock() - h.ep.state = StateSynRecv - h.ep.mu.Unlock() + h.deferAccept = deferAccept + h.ep.setEndpointState(StateSynRecv) } // checkAck checks if the ACK number, if present, of a segment received during @@ -191,6 +226,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK + // was acceptable then signal the user "error: connection reset", drop + // the segment, enter CLOSED state, delete TCB, and return." + h.ep.workerCleanup = true + // Although the RFC above calls out ECONNRESET, Linux actually returns + // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused } return nil @@ -225,6 +266,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.transitionToStateEstablishedLocked(h) + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -233,26 +277,33 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd - h.ep.mu.Lock() - h.ep.state = StateSynRecv ttl := h.ep.ttl - h.ep.mu.Unlock() + amss := h.ep.amss + h.ep.setEndpointState(StateSynRecv) synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), // We only send SACKPermitted if the other side indicated it // permits SACK. This is not explicitly defined in the RFC but // this is the behaviour implemented by Linux. SACKPermitted: rcvSynOpts.SACKPermitted, - MSS: h.ep.amss, + MSS: amss, } if ttl == 0 { ttl = s.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, tcpFields{ + id: h.ep.ID, + ttl: ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) return nil } @@ -272,6 +323,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } + // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a + // sequence number outside of the window causes an ACK with the proper seq + // number and "After sending the acknowledgment, drop the unacceptable + // segment and return." + if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + return nil + } + if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole @@ -292,17 +352,33 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { WS: h.rcvWndScale, TS: h.ep.sendTSOk, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) return nil } // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { + // If deferAccept is not zero and this is a bare ACK and the + // timeout is not hit then drop the ACK. + if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + h.acked = true + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. @@ -316,6 +392,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + + h.ep.transitionToStateEstablishedLocked(h) + + // If the segment has data then requeue it for the receiver + // to process it again once main loop is started. + if s.data.Size() > 0 { + s.incRef() + h.ep.enqueueSegment(s) + } return nil } @@ -401,7 +486,12 @@ func (h *handshake) resolveRoute() *tcpip.Error { } if n¬ifyDrain != 0 { close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() + } + if n¬ifyError != 0 { + return h.ep.takeLastError() } } @@ -418,12 +508,11 @@ func (h *handshake) execute() *tcpip.Error { } } + h.startTime = time.Now() // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, func() { - resendWaker.Assert() - }) + rt := time.AfterFunc(timeOut, resendWaker.Assert) defer rt.Stop() // Set up the wakers. @@ -442,13 +531,13 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } @@ -465,21 +554,52 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + h.ep.sendSynTCP(&h.ep.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) for h.state != handshakeCompleted { - switch index, _ := s.Fetch(true); index { + h.ep.mu.Unlock() + index, _ := s.Fetch(true) + h.ep.mu.Lock() + switch index { + case wakerForResend: timeOut *= 2 - if timeOut > 60*time.Second { + if timeOut > MaxRTO { return tcpip.ErrTimeout } rt.Reset(timeOut) - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + // Resend the SYN/SYN-ACK only if the following conditions hold. + // - It's an active handshake (deferAccept does not apply) + // - It's a passive handshake and we have not yet got the final-ACK. + // - It's a passive handshake and we got an ACK but deferAccept is + // enabled and we are now past the deferAccept duration. + // The last is required to provide a way for the peer to complete + // the connection with another ACK or data (as ACKs are never + // retransmitted on their own). + if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + h.ep.sendSynTCP(&h.ep.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) + } case wakerForNotification: n := h.ep.fetchNotifications() - if n¬ifyClose != 0 { + if (n¬ifyClose)|(n¬ifyAbort) != 0 { return tcpip.ErrAborted } if n¬ifyDrain != 0 { @@ -495,7 +615,12 @@ func (h *handshake) execute() *tcpip.Error { } } close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() + } + if n¬ifyError != 0 { + return h.ep.takeLastError() } case wakerForNewSegment: @@ -519,17 +644,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions { var optionPool = sync.Pool{ New: func() interface{} { - return make([]byte, maxOptionSize) + return &[maxOptionSize]byte{} }, } func getOptions() []byte { - return optionPool.Get().([]byte) + return (*optionPool.Get().(*[maxOptionSize]byte))[:] } func putOptions(options []byte) { // Reslice to full capacity. - optionPool.Put(options[0:cap(options)]) + optionPool.Put(optionsToArray(options)) } func makeSynOptions(opts header.TCPSynOptions) []byte { @@ -585,18 +710,33 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { - options := makeSynOptions(opts) +// tcpFields is a struct to carry different parameters required by the +// send*TCP variant functions below. +type tcpFields struct { + id stack.TransportEndpointID + ttl uint8 + tos uint8 + flags byte + seq seqnum.Value + ack seqnum.Value + rcvWnd seqnum.Size + opts []byte + txHash uint32 +} + +func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error { + tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. - if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { + if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { e.stats.SendErrors.SynSendToNetworkFailed.Increment() } - putOptions(options) + putOptions(tf.opts) return nil } -func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { + tf.txHash = e.txHash + if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() return err } @@ -604,26 +744,21 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { - optLen := len(opts) - hdr := &d.Hdr - packetSize := d.Size - off := d.Off - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) +func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { + optLen := len(tf.opts) + tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ - SrcPort: id.LocalPort, - DstPort: id.RemotePort, - SeqNum: uint32(seq), - AckNum: uint32(ack), + SrcPort: tf.id.LocalPort, + DstPort: tf.id.RemotePort, + SeqNum: uint32(tf.seq), + AckNum: uint32(tf.ack), DataOffset: uint8(header.TCPMinimumSize + optLen), - Flags: flags, - WindowSize: uint16(rcvWnd), + Flags: tf.flags, + WindowSize: uint16(tf.rcvWnd), }) - copy(tcp[header.TCPMinimumSize:], opts) + copy(tcp[header.TCPMinimumSize:], tf.opts) - length := uint16(hdr.UsedLength() + packetSize) - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We @@ -632,41 +767,53 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDe // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize) + xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } - } -func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + // We need to shallow clone the VectorisedView here as ReadToView will + // split the VectorisedView and Trim underlying views as it splits. Not + // doing the clone here will cause the underlying views of data itself + // to be altered. + data = data.Clone(nil) + + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff } mss := int(gso.MSS) n := (data.Size() + mss - 1) / mss - hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen) - size := data.Size() - off := 0 + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + var pkts stack.PacketBufferList for i := 0; i < n; i++ { packetSize := mss if packetSize > size { packetSize = size } size -= packetSize - hdrs[i].Off = off - hdrs[i].Size = packetSize - buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso) - off += packetSize - seq = seq.Add(seqnum.Size(packetSize)) - } - if ttl == 0 { - ttl = r.DefaultTTL() - } - sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrSize, + }) + pkt.Hash = tf.txHash + pkt.Owner = owner + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() + data.ReadToVV(&pkt.Data, packetSize) + buildTCPHdr(r, tf, pkt, gso) + tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkts.PushBack(pkt) + } + + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() + } + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) if err != nil { r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) } @@ -676,32 +823,33 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { - return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + return sendTCPBatch(r, tf, data, gso, owner) } - d := &stack.PacketDescriptor{ - Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - Off: 0, - Size: data.Size(), - } - buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, + Data: data, + }) + pkt.Hash = tf.txHash + pkt.Owner = owner + buildTCPHdr(r, tf, pkt, gso) - if ttl == 0 { - ttl = r.DefaultTTL() + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } r.Stats().TCP.SegmentsSent.Increment() - if (flags & header.TCPFlagRst) != 0 { + if (tf.flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } return nil @@ -730,7 +878,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } if e.sackPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) @@ -749,11 +897,20 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, tcpFields{ + id: e.ID, + ttl: e.ttl, + tos: e.sendTOS, + flags: flags, + seq: seq, + ack: ack, + rcvWnd: rcvWnd, + opts: options, + }, data, e.gso) putOptions(options) return err } @@ -768,7 +925,6 @@ func (e *endpoint) handleWrite() *tcpip.Error { first := e.sndQueue.Front() if first != nil { e.snd.writeList.PushBackList(&e.sndQueue) - e.snd.sndNxtList.UpdateForward(e.sndBufInQueue) e.sndBufInQueue = 0 } @@ -786,6 +942,9 @@ func (e *endpoint) handleWrite() *tcpip.Error { } func (e *endpoint) handleClose() *tcpip.Error { + if !e.EndpointState().connected() { + return nil + } // Drain the send queue. e.handleWrite() @@ -802,69 +961,194 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } - e.state = StateError + e.setEndpointState(StateError) e.HardError = err - if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { + resetSeqNum = e.snd.sndNxt + } + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) } } // completeWorkerLocked is called by the worker goroutine when it's about to -// exit. It marks the worker as completed and performs cleanup work if requested -// by Close(). +// exit. func (e *endpoint) completeWorkerLocked() { + // Worker is terminating(either due to moving to + // CLOSED or ERROR state, ensure we release all + // registrations port reservations even if the socket + // itself is not yet closed by the application. e.workerRunning = false if e.workerCleanup { e.cleanupLocked() } } -// handleSegments pulls segments from the queue and processes them. It returns -// no error if the protocol loop should continue, an error otherwise. -func (e *endpoint) handleSegments() *tcpip.Error { +// transitionToStateEstablisedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + + e.setEndpointState(StateEstablished) +} + +// transitionToStateCloseLocked ensures that the endpoint is +// cleaned up from the transport demuxer, "before" moving to +// StateClose. This will ensure that no packet will be +// delivered to this endpoint from the demuxer when the endpoint +// is transitioned to StateClose. +func (e *endpoint) transitionToStateCloseLocked() { + s := e.EndpointState() + if s == StateClose { + return + } + + if s.connected() { + e.stack.Stats().TCP.CurrentConnected.Decrement() + e.stack.Stats().TCP.EstablishedClosed.Increment() + } + + // Mark the endpoint as fully closed for reads/writes. + e.cleanupLocked() + e.setEndpointState(StateClose) +} + +// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed +// segment to any other endpoint other than the current one. This is called +// only when the endpoint is in StateClose and we want to deliver the segment +// to any other listening endpoint. We reply with RST if we cannot find one. +func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { + // Dual-stack socket, try IPv4. + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + } + if ep == nil { + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + s.decRef() + return + } + + if e == ep { + panic("current endpoint not removed from demuxer, enqueing segments to itself") + } + + if ep := ep.(*endpoint); ep.enqueueSegment(s) { + ep.newSegmentWaker.Assert() + } +} + +// Drain segment queue from the endpoint and try to re-match the segment to a +// different endpoint. This is used when the current endpoint is transitioned to +// StateClose and has been unregistered from the transport demuxer. +func (e *endpoint) drainClosingSegmentQueue() { + for { + s := e.segmentQueue.dequeue() + if s == nil { + break + } + + e.tryDeliverSegmentFromClosedEndpoint(s) + } +} + +func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + switch e.EndpointState() { + // In case of a RST in CLOSE-WAIT linux moves + // the socket to closed state with an error set + // to indicate EPIPE. + // + // Technically this seems to be at odds w/ RFC. + // As per https://tools.ietf.org/html/rfc793#section-2.7 + // page 69 the behavior for a segment arriving + // w/ RST bit set in CLOSE-WAIT is inlined below. + // + // ESTABLISHED + // FIN-WAIT-1 + // FIN-WAIT-2 + // CLOSE-WAIT + + // If the RST bit is set then, any outstanding RECEIVEs and + // SEND should receive "reset" responses. All segment queues + // should be flushed. Users should also receive an unsolicited + // general "connection reset" signal. Enter the CLOSED state, + // delete the TCB, and return. + case StateCloseWait: + e.transitionToStateCloseLocked() + e.HardError = tcpip.ErrAborted + e.notifyProtocolGoroutine(notifyTickleWorker) + return false, nil + default: + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + + // Notify protocol goroutine. This is required when + // handleSegment is invoked from the processor goroutine + // rather than the worker goroutine. + e.notifyProtocolGoroutine(notifyResetByPeer) + return false, tcpip.ErrConnectionReset + } + } + return true, nil +} + +// handleSegments processes all inbound segments. +func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + if e.EndpointState().closed() { + return nil + } s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false break } - // Invoke the tcp probe if installed. - if e.probe != nil { - e.probe(e.completeState()) + cont, err := e.handleSegment(s) + if err != nil { + s.decRef() + return err } - - if s.flagIsSet(header.TCPFlagRst) { - if e.rcv.acceptable(s.sequenceNumber, 0) { - // RFC 793, page 37 states that "in all states - // except SYN-SENT, all reset (RST) segments are - // validated by checking their SEQ-fields." So - // we only process it if it's acceptable. - s.decRef() - return tcpip.ErrConnectionReset - } - } else if s.flagIsSet(header.TCPFlagAck) { - // Patch the window size in the segment according to the - // send window scale. - s.window <<= e.snd.sndWndScale - - // RFC 793, page 41 states that "once in the ESTABLISHED - // state all segments must carry current acknowledgment - // information." - e.rcv.handleRcvdSegment(s) - e.snd.handleRcvdSegment(s) + if !cont { + s.decRef() + return nil } - s.decRef() } - // If the queue is not empty, make sure we'll wake up in the next - // iteration. - if checkRequeue && !e.segmentQueue.empty() { + // When fastPath is true we don't want to wake up the worker + // goroutine. If the endpoint has more segments to process the + // dispatcher will call handleSegments again anyway. + if !fastPath && checkRequeue && !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } @@ -873,23 +1157,114 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } - e.resetKeepaliveTimer(true) + e.resetKeepaliveTimer(true /* receivedData */) return nil } +func (e *endpoint) probeSegment() { + if e.probe != nil { + e.probe(e.completeState()) + } +} + +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. The tcp probe function will update + // the TCPEndpointState after the segment is processed. + defer e.probeSegment() + + if s.flagIsSet(header.TCPFlagRst) { + if ok, err := e.handleReset(s); !ok { + return false, err + } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + return false, err + } + if drop { + return true, nil + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + state := e.state + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return false, nil + } + + e.snd.handleRcvdSegment(s) + } + + return true, nil +} + // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + userTimeout := e.userTimeout + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -906,7 +1281,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -914,6 +1288,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -921,6 +1296,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -933,7 +1309,8 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { + e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -956,6 +1333,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.mu.Unlock() + e.drainClosingSegmentQueue() + // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -966,61 +1345,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // completion. initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) - e.mu.Lock() - h.ep.state = StateSynSent - e.mu.Unlock() + h.ep.setEndpointState(StateSynSent) if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() - e.mu.Lock() - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - e.state = StateError + e.setEndpointState(StateError) e.HardError = err + e.workerCleanup = true // Lock released below. epilogue() - return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - // Tell waiters that the endpoint is connected and writable. - e.mu.Lock() - if e.state != StateEstablished { - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished - } drained := e.drainDone != nil - e.mu.Unlock() if drained { close(e.drainDone) <-e.undrain } - e.waiterQueue.Notify(waiter.EventOut) - // Set up the functions that will be called when the main protocol loop // wakes up. funcs := []struct { @@ -1036,25 +1386,33 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { f: e.handleClose, }, { - w: &e.newSegmentWaker, - f: e.handleSegments, - }, - { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP-FIN-WAIT2 timeout was hit. Just + // mark the socket as closed. + e.transitionToStateCloseLocked() + e.workerCleanup = true + return nil }, }, { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil }, }, { + w: &e.newSegmentWaker, + f: func() *tcpip.Error { + return e.handleSegments(false /* fastPath */) + }, + }, + { w: &e.keepalive.waker, f: e.keepaliveTimerExpired, }, @@ -1080,22 +1438,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { - e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.mu.Unlock() + if n¬ifyReset != 0 || n¬ifyAbort != 0 { + return tcpip.ErrConnectionAborted + } + + if n¬ifyResetByPeer != 0 { + return tcpip.ErrConnectionReset } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + if e.EndpointState() == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } } if n¬ifyKeepaliveChanged != 0 { @@ -1107,16 +1464,26 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(); err != nil { + if err := e.handleSegments(false /* fastPath */); err != nil { return err } } - if e.state != StateError { + if !e.EndpointState().closed() { + // Only block the worker if the endpoint + // is not in closed state or error state. close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1128,14 +1495,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.AddWaker(funcs[i].w, i) } + // Notify the caller that the waker initialization is complete and the + // endpoint is ready. + if wakerInitDone != nil { + close(wakerInitDone) + } + + // Tell waiters that the endpoint is connected and writable. + e.waiterQueue.Notify(waiter.EventOut) + // The following assertions and notifications are needed for restored // endpoints. Fresh newly created endpoints have empty states and should // not invoke any. - e.segmentQueue.mu.Lock() - if !e.segmentQueue.list.Empty() { + if !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } - e.segmentQueue.mu.Unlock() e.rcvListMu.Lock() if !e.rcvList.Empty() { @@ -1143,41 +1517,209 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { - e.workMu.Unlock() - v, _ := s.Fetch(true) - e.workMu.Lock() - if err := funcs[v].f(); err != nil { - e.mu.Lock() - // Ensure we release all endpoint registration and route - // references as the connection is now in an error - // state. - e.workerCleanup = true + cleanupOnError := func(err *tcpip.Error) { + e.stack.Stats().TCP.CurrentConnected.Decrement() + e.workerCleanup = true + if err != nil { e.resetConnectionLocked(err) - // Lock released below. - epilogue() + } + // Lock released below. + epilogue() + } +loop: + for { + switch e.EndpointState() { + case StateTimeWait, StateClose, StateError: + break loop + } + + e.mu.Unlock() + v, _ := s.Fetch(true) + e.mu.Lock() + + // We need to double check here because the notification may be + // stale by the time we got around to processing it. + switch e.EndpointState() { + case StateError: + // If the endpoint has already transitioned to an ERROR + // state just pass nil here as any reset that may need + // to be sent etc should already have been done and we + // just want to terminate the loop and cleanup the + // endpoint. + cleanupOnError(nil) return nil + case StateTimeWait: + fallthrough + case StateClose: + break loop + default: + if err := funcs[v].f(); err != nil { + cleanupOnError(err) + return nil + } } } - // Mark endpoint as closed. - e.mu.Lock() - if e.state != StateError { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - e.state = StateClose + var reuseTW func() + if e.EndpointState() == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.workerCleanup = true + reuseTW = e.doTimeWait() + } + + // Handle any StateError transition from StateTimeWait. + if e.EndpointState() == StateError { + cleanupOnError(nil) + return nil } + + e.transitionToStateCloseLocked() + // Lock released below. epilogue() + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + if !tcpEP.enqueueSegment(s) { + s.decRef() + return + } + tcpEP.newSegmentWaker.Assert() + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + defer s.Done() + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.mu.Unlock() + v, _ := s.Fetch(true) + e.mu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyAbort != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + e.mu.Unlock() + <-e.undrain + e.mu.Lock() + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go new file mode 100644 index 000000000..cfc304616 --- /dev/null +++ b/pkg/tcpip/transport/tcp/connect_unsafe.go @@ -0,0 +1,30 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "reflect" + "unsafe" +) + +// optionsToArray converts a slice of capacity >-= maxOptionSize to an array. +// +// optionsToArray panics if the capacity of options is smaller than +// maxOptionSize. +func optionsToArray(options []byte) *[maxOptionSize]byte { + // Reslice to full capacity. + options = options[0:maxOptionSize] + return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data)) +} diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go new file mode 100644 index 000000000..98aecab9e --- /dev/null +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -0,0 +1,234 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// epQueue is a queue of endpoints. +type epQueue struct { + mu sync.Mutex + list endpointList +} + +// enqueue adds e to the queue if the endpoint is not already on the queue. +func (q *epQueue) enqueue(e *endpoint) { + q.mu.Lock() + if e.pendingProcessing { + q.mu.Unlock() + return + } + q.list.PushBack(e) + e.pendingProcessing = true + q.mu.Unlock() +} + +// dequeue removes and returns the first element from the queue if available, +// returns nil otherwise. +func (q *epQueue) dequeue() *endpoint { + q.mu.Lock() + if e := q.list.Front(); e != nil { + q.list.Remove(e) + e.pendingProcessing = false + q.mu.Unlock() + return e + } + q.mu.Unlock() + return nil +} + +// empty returns true if the queue is empty, false otherwise. +func (q *epQueue) empty() bool { + q.mu.Lock() + v := q.list.Empty() + q.mu.Unlock() + return v +} + +// processor is responsible for processing packets queued to a tcp endpoint. +type processor struct { + epQ epQueue + sleeper sleep.Sleeper + newEndpointWaker sleep.Waker + closeWaker sleep.Waker +} + +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) queueEndpoint(ep *endpoint) { + // Queue an endpoint for processing by the processor goroutine. + p.epQ.enqueue(ep) + p.newEndpointWaker.Assert() +} + +const ( + newEndpointWaker = 1 + closeWaker = 2 +) + +func (p *processor) start(wg *sync.WaitGroup) { + defer wg.Done() + defer p.sleeper.Done() + + for { + if id, _ := p.sleeper.Fetch(true); id == closeWaker { + break + } + for { + ep := p.epQ.dequeue() + if ep == nil { + break + } + if ep.segmentQueue.empty() { + continue + } + + // If socket has transitioned out of connected state then just let the + // worker handle the packet. + // + // NOTE: We read this outside of e.mu lock which means that by the time + // we get to handleSegments the endpoint may not be in ESTABLISHED. But + // this should be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a CLOSED/ERROR state + // then handleSegments is a noop. + if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { + // If the endpoint is in a connected state then we do direct delivery + // to ensure low latency and avoid scheduler interactions. + switch err := ep.handleSegments(true /* fastPath */); { + case err != nil: + // Send any active resets if required. + ep.resetConnectionLocked(err) + fallthrough + case ep.EndpointState() == StateClose: + ep.notifyProtocolGoroutine(notifyTickleWorker) + case !ep.segmentQueue.empty(): + p.epQ.enqueue(ep) + } + ep.mu.Unlock() + } else { + ep.newSegmentWaker.Assert() + } + } + } +} + +// dispatcher manages a pool of TCP endpoint processors which are responsible +// for the processing of inbound segments. This fixed pool of processor +// goroutines do full tcp processing. The processor is selected based on the +// hash of the endpoint id to ensure that delivery for the same endpoint happens +// in-order. +type dispatcher struct { + processors []processor + seed uint32 + wg sync.WaitGroup +} + +func (d *dispatcher) init(nProcessors int) { + d.close() + d.wait() + d.processors = make([]processor, nProcessors) + d.seed = generateRandUint32() + for i := range d.processors { + p := &d.processors[i] + p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) + p.sleeper.AddWaker(&p.closeWaker, closeWaker) + d.wg.Add(1) + // NB: sleeper-waker registration must happen synchronously to avoid races + // with `close`. It's possible to pull all this logic into `start`, but + // that results in a heap-allocated function literal. + go p.start(&d.wg) + } +} + +func (d *dispatcher) close() { + for i := range d.processors { + d.processors[i].close() + } +} + +func (d *dispatcher) wait() { + d.wg.Wait() +} + +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + ep := stackEP.(*endpoint) + s := newSegment(r, id, pkt) + if !s.parse() { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + s.decRef() + return + } + + if !s.csumValid { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.ChecksumErrors.Increment() + ep.stats.ReceiveErrors.ChecksumErrors.Increment() + s.decRef() + return + } + + ep.stack.Stats().TCP.ValidSegmentsReceived.Increment() + ep.stats.SegmentsReceived.Increment() + if (s.flags & header.TCPFlagRst) != 0 { + ep.stack.Stats().TCP.ResetsReceived.Increment() + } + + if !ep.enqueueSegment(s) { + s.decRef() + return + } + + // For sockets not in established state let the worker goroutine + // handle the packets. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + return + } + + d.selectProcessor(id).queueEndpoint(ep) +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return binary.LittleEndian.Uint32(b) +} + +func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { + var payload [4]byte + binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) + binary.LittleEndian.PutUint16(payload[2:], id.RemotePort) + + h := jenkins.Sum32(d.seed) + h.Write(payload[:]) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + + return &d.processors[h.Sum32()%uint32(len(d.processors))] +} diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index dfaa4a559..804e95aea 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) { // Make sure we get the same error when calling the original ep and the // new one. This validates that v4-mapped endpoints are still able to // query the V6Only flag, whereas pure v4 endpoints are not. - var v tcpip.V6OnlyOption - expected := c.EP.GetSockOpt(&v) - if err := nep.GetSockOpt(&v); err != expected { + _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) } @@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) { // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. - var v tcpip.V6OnlyOption - if err := nep.GetSockOpt(&v); err != nil { + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { t.Fatalf("GetSockOpt failed failed: %v", err) } @@ -570,11 +568,10 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + } + const n = uint16(32) // Start listening. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6ca0d73a9..1ccedebcc 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -18,21 +18,21 @@ import ( "encoding/binary" "fmt" "math" + "runtime" "strings" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tmutex" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +63,8 @@ const ( StateClosing ) -// connected is the set of states where an endpoint is connected to a peer. +// connected returns true when s is one of the states representing an +// endpoint connected to a peer. func (s EndpointState) connected() bool { switch s { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: @@ -73,6 +74,40 @@ func (s EndpointState) connected() bool { } } +// connecting returns true when s is one of the states representing a +// connection in progress, but not yet fully established. +func (s EndpointState) connecting() bool { + switch s { + case StateConnecting, StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// handshake returns true when s is one of the states representing an endpoint +// in the middle of a TCP handshake. +func (s EndpointState) handshake() bool { + switch s { + case StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// closed returns true when s is one of the states an endpoint transitions to +// when closed or when it encounters an error. This is distinct from a newly +// initialized endpoint that was never connected. +func (s EndpointState) closed() bool { + switch s { + case StateClose, StateError: + return true + default: + return false + } +} + // String implements fmt.Stringer.String. func (s EndpointState) String() string { switch s { @@ -119,8 +154,17 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker + notifyError ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -273,20 +317,59 @@ func (*EndpointInfo) IsEndpointInfo() {} // synchronized. The protocol implementation, however, runs in a single // goroutine. // +// Each endpoint has a few mutexes: +// +// e.mu -> Primary mutex for an endpoint must be held for all operations except +// in e.Readiness where acquiring it will result in a deadlock in epoll +// implementation. +// +// The following three mutexes can be acquired independent of e.mu but if +// acquired with e.mu then e.mu must be acquired first. +// +// e.acceptMu -> protects acceptedChan. +// e.rcvListMu -> Protects the rcvList and associated fields. +// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.lastErrorMu -> Protects the lastError field. +// +// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different +// based on the context in which the lock is acquired. In the syscall context +// e.LockUser/e.UnlockUser should be used and when doing background processing +// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below +// in brief. +// +// The reason for this locking behaviour is to avoid wakeups to handle packets. +// In cases where the endpoint is already locked the background processor can +// queue the packet up and go its merry way and the lock owner will eventually +// process the backlog when releasing the lock. Similarly when acquiring the +// lock from say a syscall goroutine we can implement a bit of spinning if we +// know that the lock is not held by another syscall goroutine. Background +// processors should never hold the lock for long and we can avoid an expensive +// sleep/wakeup by spinning for a shortwhile. +// +// For more details please see the detailed documentation on +// e.LockUser/e.UnlockUser methods. +// // +stateify savable type endpoint struct { EndpointInfo - // workMu is used to arbitrate which goroutine may perform protocol - // work. Only the main protocol goroutine is expected to call Lock() on - // it, but other goroutines (e.g., send) may call TryLock() to eagerly - // perform work without having to wait for the main one to wake up. - workMu tmutex.Mutex `state:"nosave"` + // endpointEntry is used to queue endpoints for processing to the + // a given tcp processor goroutine. + // + // Precondition: epQueue.mu must be held to read/write this field.. + endpointEntry `state:"nosave"` + + // pendingProcessing is true if this endpoint is queued for processing + // to a TCP processor. + // + // Precondition: epQueue.mu must be held to read/write this field.. + pendingProcessing bool `state:"nosave"` // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -307,21 +390,24 @@ type endpoint struct { rcvBufSize int rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams - // zeroWindow indicates that the window was closed due to receive buffer - // space being filled up. This is set by the worker goroutine before - // moving a segment to the rcvList. This setting is cleared by the - // endpoint when a Read() call reads enough data for the new window to - // be non-zero. - zeroWindow bool - // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` + // mu protects all endpoint fields unless documented otherwise. mu must + // be acquired before interacting with the endpoint fields. + mu sync.Mutex `state:"nosave"` + ownedByUser uint32 + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` - isRegistered bool - boundNICID tcpip.NICID `state:"manual"` + isRegistered bool `state:"manual"` + boundNICID tcpip.NICID route stack.Route `state:"manual"` ttl uint8 v6only bool @@ -330,19 +416,28 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // portFlags stores the current values of port related flags. + portFlags ports.Flags + + // Values used to reserve a port or register a transport endpoint + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + boundDest tcpip.FullAddress + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` + effectiveNetProtos []tcpip.NetworkProtocolNumber // workerRunning specifies if a worker goroutine is running. workerRunning bool // workerCleanup specifies if the worker goroutine must perform cleanup - // before exitting. This can only be set to true when workerRunning is + // before exiting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. workerCleanup bool @@ -356,6 +451,9 @@ type endpoint struct { // updated if required when a new segment is received by this endpoint. recentTS uint32 + // recentTSTime is the unix time when we updated recentTS last. + recentTSTime time.Time `state:".(unixTime)"` + // tsOffset is a randomized offset added to the value of the // TSVal field in the timestamp option. tsOffset uint32 @@ -370,9 +468,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // reusePort is set to true if SO_REUSEPORT is enabled. - reusePort bool - // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -392,7 +487,6 @@ type endpoint struct { // The options below aren't implemented, but we remember the user // settings because applications expect to be able to set/query these // options. - reuseAddr bool // slowAck holds the negated state of quick ack. It is stubbed out and // does nothing. @@ -411,7 +505,18 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 + + // maxSynRetries is the maximum number of SYN retransmits that TCP should + // send before aborting the attempt to connect. It cannot exceed 255. + // + // NOTE: This is currently a no-op and does not change the SYN + // retransmissions. + maxSynRetries uint8 + + // windowClamp is used to bound the size of the advertised window to + // this value. + windowClamp uint32 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -458,12 +563,42 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + + // deferAccept if non-zero specifies a user specified time during + // which the final ACK of a handshake will be dropped provided the + // ACK is a bare ACK and carries no data. If the timeout is crossed then + // the bare ACK is accepted and the connection is delivered to the + // listener. + deferAccept time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing // to the acceptedChan below terminate before we close acceptedChan. pendingAccepted sync.WaitGroup `state:"nosave"` + // acceptMu protects acceptedChan. + acceptMu sync.Mutex `state:"nosave"` + + // acceptCond is a condition variable that can be used to block on when + // acceptedChan is full and an endpoint is ready to be delivered. + // + // This condition variable is required because just blocking on sending + // to acceptedChan does not work in cases where endpoint.Listen is + // called twice with different backlog values. In such cases the channel + // is closed and a new one created. Any pending goroutines blocking on + // the write to the channel will panic. + // + // We use this condition variable to block/unblock goroutines which + // tried to deliver an endpoint but couldn't because accept backlog was + // full ( See: endpoint.deliverAccepted ). + acceptCond *sync.Cond `state:"nosave"` + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -502,16 +637,175 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool + + // txHash is the transport layer hash to be set on outbound packets + // emitted by this endpoint. + txHash uint32 + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + // TODO(b/143359391): Respect TCP Min and Max size. + maxMSS := uint16(r.MTU() - header.TCPMinimumSize) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS +} + +// LockUser tries to lock e.mu and if it fails it will check if the lock is held +// by another syscall goroutine. If yes, then it will goto sleep waiting for the +// lock to be released, if not then it will spin till it acquires the lock or +// another syscall goroutine acquires it in which case it will goto sleep as +// described above. +// +// The assumption behind spinning here being that background packet processing +// should not be holding the lock for long and spinning reduces latency as we +// avoid an expensive sleep/wakeup of of the syscall goroutine). +func (e *endpoint) LockUser() { + for { + // Try first if the sock is locked then check if it's owned + // by another user goroutine if not then we spin, otherwise + // we just goto sleep on the Lock() and wait. + if !e.mu.TryLock() { + // If socket is owned by the user then just goto sleep + // as the lock could be held for a reasonably long time. + if atomic.LoadUint32(&e.ownedByUser) == 1 { + e.mu.Lock() + atomic.StoreUint32(&e.ownedByUser, 1) + return + } + // Spin but yield the processor since the lower half + // should yield the lock soon. + runtime.Gosched() + continue + } + atomic.StoreUint32(&e.ownedByUser, 1) + return + } +} + +// UnlockUser will check if there are any segments already queued for processing +// and process any such segments before unlocking e.mu. This is required because +// we when packets arrive and endpoint lock is already held then such packets +// are queued up to be processed. If the lock is held by the endpoint goroutine +// then it will process these packets but if the lock is instead held by the +// syscall goroutine then we can have the syscall goroutine process the backlog +// before unlocking. +// +// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the +// endpoint. It's also required eventually when we get rid of the endpoint +// protocol goroutine altogether. +// +// Precondition: e.LockUser() must have been called before calling e.UnlockUser() +func (e *endpoint) UnlockUser() { + // Lock segment queue before checking so that we avoid a race where + // segments can be queued between the time we check if queue is empty + // and actually unlock the endpoint mutex. + for { + e.segmentQueue.mu.Lock() + if e.segmentQueue.emptyLocked() { + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + e.segmentQueue.mu.Unlock() + return + } + e.segmentQueue.mu.Unlock() + + switch e.EndpointState() { + case StateEstablished: + if err := e.handleSegments(true /* fastPath */); err != nil { + e.notifyProtocolGoroutine(notifyTickleWorker) + } + default: + // Since we are waking the endpoint goroutine here just unlock + // and let it process the queued segments. + e.newSegmentWaker.Assert() + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + return + } + } } // StopWork halts packet processing. Only to be used in tests. func (e *endpoint) StopWork() { - e.workMu.Lock() + e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. func (e *endpoint) ResumeWork() { - e.workMu.Unlock() + e.mu.Unlock() +} + +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + switch state { + case StateEstablished: + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.stack.Stats().TCP.CurrentConnected.Increment() + case StateError: + fallthrough + case StateClose: + if oldstate == StateCloseWait || oldstate == StateEstablished { + e.stack.Stats().TCP.EstablishedResets.Increment() + } + fallthrough + default: + if oldstate == StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } + } + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp sets the recentTS field to the provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + e.recentTS = recentTS + e.recentTSTime = time.Now() +} + +// recentTimestamp returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return e.recentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -543,13 +837,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSize: DefaultReceiveBufferSize, sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), - reuseAddr: true, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), + txHash: s.Rand().Uint32(), + windowClamp: DefaultReceiveBufferSize, + maxSynRetries: DefaultSynRetries, } var ss SendBufferSizeOption @@ -572,14 +869,28 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptBool(tcpip.DelayOption, true) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + + var synRetries tcpip.TCPSynRetriesOption + if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil { + e.maxSynRetries = uint8(synRetries) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - e.workMu.Lock() e.tsOffset = timeStampOffset() + e.acceptCond = sync.NewCond(&e.acceptMu) return e } @@ -589,26 +900,25 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) - e.mu.RLock() - defer e.mu.RUnlock() - - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case StateClose, StateError: + case StateClose, StateError, StateTimeWait: // Ready for anything. result = mask case StateListen: // Check if there's anything in the accepted channel. if (mask & waiter.EventIn) != 0 { + e.acceptMu.Lock() if len(e.acceptedChan) > 0 { result |= waiter.EventIn } + e.acceptMu.Unlock() } } - if e.state.connected() { + if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -655,69 +965,117 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + // + // If the endpoint connecting finishes after the check, the endpoint + // is either in a connected state (where we would notifyAbort anyway), + // SYN-RECV (where we would also notifyAbort anyway), or in an error + // state where nothing is required and the notification can be safely + // ignored. + // + // Endpoints where a Close during connecting or SYN-RECV state would be + // problematic are set to state connecting before being registered (and + // thus possible to be Aborted). They are never available in initial + // state. + // + // Endpoints transitioning from initial to connecting state may be + // safely either closed or sent notifyAbort. + if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.LockUser() + defer e.UnlockUser() + if e.closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. - e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - - e.mu.Lock() + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdownLocked() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdownLocked() { // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == StateListen && e.isPortReserved { + if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} + } + + // Mark endpoint as closed. + e.closed = true + + switch e.EndpointState() { + case StateClose, StateError: + return } // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. - tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { - e.cleanupLocked() - } else { + if e.workerRunning { e.workerCleanup = true + tcpip.AddDanglingEndpoint(e) + // Worker will remove the dangling endpoint when the endpoint + // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) + } else { + e.transitionToStateCloseLocked() } - - e.mu.Unlock() } // closePendingAcceptableConnections closes all connections that have completed // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { - done := make(chan struct{}) - // Spin a goroutine up as ranging on e.acceptedChan will just block when - // there are no more connections in the channel. Using a non-blocking - // select does not work as it can potentially select the default case - // even when there are pending writes but that are not yet written to - // the channel. - go func() { - defer close(done) - for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() - n.Close() - } - }() - // pendingAccepted(see endpoint.deliverAccepted) tracks the number of - // endpoints which have completed handshake but are not yet written to - // the e.acceptedChan. We wait here till the goroutine above can drain - // all such connections from e.acceptedChan. - e.pendingAccepted.Wait() + e.acceptMu.Lock() + if e.acceptedChan == nil { + e.acceptMu.Unlock() + return + } close(e.acceptedChan) - <-done + ch := e.acceptedChan e.acceptedChan = nil + e.acceptCond.Broadcast() + e.acceptMu.Unlock() + + // Reset all connections that are waiting to be accepted. + for n := range ch { + n.notifyProtocolGoroutine(notifyReset) + } + // Wait for reset of all endpoints that are still waiting to be delivered to + // the now closed acceptedChan. + e.pendingAccepted.Wait() } // cleanupLocked frees all resources associated with the endpoint. It is called @@ -726,22 +1084,25 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. - if e.acceptedChan != nil { - e.closePendingAcceptableConnectionsLocked() - } + e.closePendingAcceptableConnectionsLocked() + e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false } + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -752,16 +1113,34 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } + rcvWndScale := e.rcvWndScaleForHandshake() + + // Round-down the rcvWnd to a multiple of wndScale. This ensures that the + // window offered in SYN won't be reduced due to the loss of precision if + // window scaling is enabled after the handshake. + rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale) + + // Ensure we can always accept at least 1 byte if the scale specified + // was too high for the provided rcvWnd. + if rcvWnd == 0 { + rcvWnd = 1 + } + return rcvWnd } // ModerateRecvBuf adjusts the receive buffer and the advertised window -// based on the number of bytes copied to user space. +// based on the number of bytes copied to userspace. func (e *endpoint) ModerateRecvBuf(copied int) { + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() @@ -807,8 +1186,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { + availBefore := e.receiveBufferAvailableLocked() e.rcvBufSize = rcvWnd - e.notifyProtocolGoroutine(notifyReceiveWindowChanged) + availAfter := e.receiveBufferAvailableLocked() + mask := uint32(notifyReceiveWindowChanged) + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { + mask |= notifyNonZeroReceiveWindow + } + e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -822,36 +1207,50 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { - return e.stack.IPTables(), nil +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err } // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() + e.LockUser() + defer e.UnlockUser() + + // When in SYN-SENT state, let the caller block on the receive. + // An application can initiate a non-blocking connect and then block + // on a receive. It can expect to read any data after the handshake + // is complete. RFC793, section 3.9, p58. + if e.EndpointState() == StateSynSent { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + } + // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.mu.RUnlock() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + e.stats.ReadErrors.NotConnected.Increment() + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected } v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() - if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() } @@ -860,7 +1259,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -877,11 +1276,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { } e.rcvBufUsed -= len(v) - // If the window was zero before this read and if the read freed up - // enough buffer space for the scaled window to be non-zero then notify - // the protocol goroutine to send a window update. - if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) { - e.zeroWindow = false + + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -895,8 +1295,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. - if !e.state.connected() { - switch e.state { + if !e.EndpointState().connected() { + switch e.EndpointState() { case StateError: return 0, e.HardError default: @@ -922,13 +1322,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. - e.mu.RLock() + e.LockUser() e.sndBufMu.Lock() avail, err := e.isEndpointWritableLocked() if err != nil { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -940,73 +1340,72 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // are copying data in. if !opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } // Fetch data. v, perr := p.Payload(avail) if perr != nil || len(v) == 0 { - if opts.Atomic { // See above. + // Note that perr may be nil if len(v) == 0. + if opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } - // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } + // Do the work inline. + e.handleWrite() + e.UnlockUser() + return int64(len(v)), nil, nil + } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } + if opts.Atomic { + // Locks released in queueAndSend() + return queueAndSend() } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + e.LockUser() + e.sndBufMu.Lock() + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - if e.workMu.TryLock() { - // Do the work inline. - e.handleWrite() - e.workMu.Unlock() - } else { - // Let the protocol goroutine do the work. - e.sndWaker.Assert() + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } - return int64(len(v)), nil, nil + // Locks released in queueAndSend() + return queueAndSend() } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; !s.connected() && s != StateClose { + if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { return 0, tcpip.ControlMessages{}, e.HardError } @@ -1018,7 +1417,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } @@ -1055,37 +1454,174 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// zeroReceiveWindow checks if the receive window to be announced now would be -// zero, based on the amount of available buffer and the receive window scaling. +// windowCrossedACKThresholdLocked checks if the receive window to be announced +// now would be under aMSS or under half receive buffer, whichever smaller. This +// is useful as a receive side silly window syndrome prevention mechanism. If +// window grows to reasonable value, we should send ACK to the sender to inform +// the rx space is now large. We also want ensure a series of small read()'s +// won't trigger a flood of spurious tiny ACK's. // -// It must be called with rcvListMu held. -func (e *endpoint) zeroReceiveWindow(scale uint8) bool { - if e.rcvBufUsed >= e.rcvBufSize { - return true +// For large receive buffers, the threshold is aMSS - once reader reads more +// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of +// receive buffer size. This is chosen arbitrairly. +// crossed will be true if the window size crossed the ACK threshold. +// above will be true if the new window is >= ACK threshold and false +// otherwise. +// +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { + newAvail := e.receiveBufferAvailableLocked() + oldAvail := newAvail - deltaBefore + if oldAvail < 0 { + oldAvail = 0 + } + + threshold := int(e.amss) + if threshold > e.rcvBufSize/2 { + threshold = e.rcvBufSize / 2 + } + + switch { + case oldAvail < threshold && newAvail >= threshold: + return true, true + case oldAvail >= threshold && newAvail < threshold: + return true, false + } + return false, false +} + +// SetSockOptBool sets a socket option. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + + case tcpip.BroadcastOption: + e.LockUser() + e.broadcast = v + e.UnlockUser() + + case tcpip.CorkOption: + e.LockUser() + if !v { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + e.UnlockUser() + + case tcpip.DelayOption: + if v { + atomic.StoreUint32(&e.delay, 1) + } else { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.QuickAckOption: + o := uint32(1) + if v { + o = 0 + } + atomic.StoreUint32(&e.slowAck, o) + + case tcpip.ReuseAddressOption: + e.LockUser() + e.portFlags.TupleOnly = v + e.UnlockUser() + + case tcpip.ReusePortOption: + e.LockUser() + e.portFlags.LoadBalanced = v + e.UnlockUser() + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + // We only allow this to be set when we're in the initial state. + if e.EndpointState() != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.LockUser() + e.v6only = v + e.UnlockUser() } - return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 + return nil } // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.IPv4TOSOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.userMSS = uint16(userMSS) + e.UnlockUser() + e.notifyProtocolGoroutine(notifyMSSChanged) + + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. var rs ReceiveBufferSizeOption - size := int(v) if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if size < rs.Min { - size = rs.Min + if v < rs.Min { + v = rs.Min } - if size > rs.Max { - size = rs.Max + if v > rs.Max { + v = rs.Max } } mask := uint32(notifyReceiveWindowChanged) + e.LockUser() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1094,179 +1630,119 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { if e.rcv != nil { scale = e.rcv.rcvWndScale } - if size>>scale == 0 { - size = 1 << scale + if v>>scale == 0 { + v = 1 << scale } // Make sure 2*size doesn't overflow. - if size > math.MaxInt32/2 { - size = math.MaxInt32 / 2 + if v > math.MaxInt32/2 { + v = math.MaxInt32 / 2 } - e.rcvBufSize = size + availBefore := e.receiveBufferAvailableLocked() + e.rcvBufSize = v + availAfter := e.receiveBufferAvailableLocked() + e.rcvAutoParams.disabled = true - if e.zeroWindow && !e.zeroReceiveWindow(scale) { - e.zeroWindow = false + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } - e.rcvListMu.Unlock() + e.rcvListMu.Unlock() + e.UnlockUser() e.notifyProtocolGoroutine(mask) - return nil case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - size := int(v) var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if size < ss.Min { - size = ss.Min + if v < ss.Min { + v = ss.Min } - if size > ss.Max { - size = ss.Max + if v > ss.Max { + v = ss.Max } } e.sndBufMu.Lock() - e.sndBufSize = size + e.sndBufSize = v e.sndBufMu.Unlock() - return nil - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) + case tcpip.TTLOption: + e.LockUser() + e.ttl = uint8(v) + e.UnlockUser() - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) + case tcpip.TCPSynCountOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue } - return nil + e.LockUser() + e.maxSynRetries = uint8(v) + e.UnlockUser() - default: - return nil + case tcpip.TCPWindowClampOption: + if v == 0 { + e.LockUser() + switch e.EndpointState() { + case StateClose, StateInitial: + e.windowClamp = 0 + e.UnlockUser() + return nil + default: + e.UnlockUser() + return tcpip.ErrInvalidOptionValue + } + } + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if v < rs.Min/2 { + v = rs.Min / 2 + } + } + e.LockUser() + e.windowClamp = uint32(v) + e.UnlockUser() } + return nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 switch v := opt.(type) { - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - for nicid, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicid - return nil - } - } - return tcpip.ErrUnknownDevice - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v != 0 - return nil - - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - return nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v != 0 - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + e.LockUser() + e.bindToDevice = id + e.UnlockUser() case tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIntervalOption: e.keepalive.Lock() e.keepalive.interval = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil - case tcpip.KeepaliveCountOption: - e.keepalive.Lock() - e.keepalive.count = int(v) - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + case tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v != 0 - e.mu.Unlock() - return nil + case tcpip.TCPUserTimeoutOption: + e.LockUser() + e.userTimeout = time.Duration(v) + e.UnlockUser() case tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and @@ -1279,22 +1755,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { availCC := strings.Split(string(avail), " ") for _, cc := range availCC { if v == tcpip.CongestionControlOption(cc) { - // Acquire the work mutex as we may need to - // reinitialize the congestion control state. - e.mu.Lock() - state := e.state + e.LockUser() + state := e.EndpointState() e.cc = v - e.mu.Unlock() switch state { case StateEstablished: - e.workMu.Lock() - e.mu.Lock() - if e.state == state { + if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } - e.mu.Unlock() - e.workMu.Unlock() } + e.UnlockUser() return nil } } @@ -1303,34 +1773,44 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.IPv4TOSOption: - e.mu.Lock() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() - return nil + case tcpip.TCPLingerTimeoutOption: + e.LockUser() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + // Cap it to MaxTCPLingerTimeout. + stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.UnlockUser() - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + case tcpip.TCPDeferAcceptOption: + e.LockUser() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.UnlockUser() + + case tcpip.SocketDetachFilterOption: return nil default: return nil } + return nil } // readyReceiveSize returns the number of bytes ready to be received. func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint cannot be in listen state. - if e.state == StateListen { + if e.EndpointState() == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1340,9 +1820,100 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.BroadcastOption: + e.LockUser() + v := e.broadcast + e.UnlockUser() + return v, nil + + case tcpip.CorkOption: + return atomic.LoadUint32(&e.cork) != 0, nil + + case tcpip.DelayOption: + return atomic.LoadUint32(&e.delay) != 0, nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + return v, nil + + case tcpip.QuickAckOption: + v := atomic.LoadUint32(&e.slowAck) == 0 + return v, nil + + case tcpip.ReuseAddressOption: + e.LockUser() + v := e.portFlags.TupleOnly + e.UnlockUser() + + return v, nil + + case tcpip.ReusePortOption: + e.LockUser() + v := e.portFlags.LoadBalanced + e.UnlockUser() + + return v, nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.LockUser() + v := e.v6only + e.UnlockUser() + + return v, nil + + case tcpip.MulticastLoopOption: + return true, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + v := e.keepalive.count + e.keepalive.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + v := header.TCPDefaultMSS + return v, nil + + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1358,12 +1929,26 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil - case tcpip.DelayOption: - var o int - if v := atomic.LoadUint32(&e.delay); v != 0 { - o = 1 - } - return o, nil + case tcpip.TTLOption: + e.LockUser() + v := int(e.ttl) + e.UnlockUser() + return v, nil + + case tcpip.TCPSynCountOption: + e.LockUser() + v := int(e.maxSynRetries) + e.UnlockUser() + return v, nil + + case tcpip.TCPWindowClampOption: + e.LockUser() + v := int(e.windowClamp) + e.UnlockUser() + return v, nil + + case tcpip.MulticastTTLOption: + return 1, nil default: return -1, tcpip.ErrUnknownProtocolOption @@ -1374,191 +1959,84 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err - - case *tcpip.MaxSegOption: - // This is just stubbed out. Linux never returns the user_mss - // value as it either returns the defaultMSS or returns the - // actual current MSS. Netstack just returns the defaultMSS - // always for now. - *o = header.TCPDefaultMSS - return nil - - case *tcpip.CorkOption: - *o = 0 - if v := atomic.LoadUint32(&e.cork); v != 0 { - *o = 1 - } - return nil - - case *tcpip.ReuseAddressOption: - e.mu.RLock() - v := e.reuseAddr - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.ReusePortOption: - e.mu.RLock() - v := e.reusePort - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil + return e.takeLastError() case *tcpip.BindToDeviceOption: - e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = "" - return nil - - case *tcpip.QuickAckOption: - *o = 1 - if v := atomic.LoadUint32(&e.slowAck); v != 0 { - *o = 0 - } - return nil - - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.TTLOption: - e.mu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() - return nil + e.LockUser() + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.UnlockUser() case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} - e.mu.RLock() + e.LockUser() snd := e.snd - e.mu.RUnlock() + e.UnlockUser() if snd != nil { snd.rtt.Lock() o.RTT = snd.rtt.srtt o.RTTVar = snd.rtt.rttvar snd.rtt.Unlock() } - return nil - - case *tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) e.keepalive.Unlock() - return nil case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) e.keepalive.Unlock() - return nil - case *tcpip.KeepaliveCountOption: - e.keepalive.Lock() - *o = tcpip.KeepaliveCountOption(e.keepalive.count) - e.keepalive.Unlock() - return nil + case *tcpip.TCPUserTimeoutOption: + e.LockUser() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.UnlockUser() case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 - return nil - - case *tcpip.BroadcastOption: - e.mu.Lock() - v := e.broadcast - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.CongestionControlOption: - e.mu.Lock() + e.LockUser() *o = e.cc - e.mu.Unlock() - return nil + e.UnlockUser() - case *tcpip.IPv4TOSOption: - e.mu.RLock() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() - return nil + case *tcpip.TCPLingerTimeoutOption: + e.LockUser() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.UnlockUser() - case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() - return nil + case *tcpip.TCPDeferAcceptOption: + e.LockUser() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.UnlockUser() + + case *tcpip.OriginalDestinationOption: + ipt := e.stack.IPTables() + addr, port, err := ipt.OriginalDst(e.ID) + if err != nil { + return err + } + *o = tcpip.OriginalDestinationOption{ + Addr: addr, + Port: port, + } default: return tcpip.ErrUnknownProtocolOption } + return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + if err != nil { + return tcpip.FullAddress{}, 0, err } - - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1583,17 +2061,17 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // yet accepted by the app, they are restored without running the main goroutine // here. func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } - if e.state.connected() { + if e.EndpointState().connected() { // The endpoint is already connected. If caller hasn't been // notified yet, return success. if !e.isConnectNotified { @@ -1604,8 +2082,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnected } - nicid := addr.NIC - switch e.state { + nicID := addr.NIC + switch e.EndpointState() { case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. @@ -1613,11 +2091,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1636,14 +2114,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() - origID := e.ID - netProtos := []tcpip.NetworkProtocolNumber{netProto} e.ID.LocalAddress = r.LocalAddress e.ID.RemoteAddress = r.RemoteAddress @@ -1651,7 +2127,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -1666,7 +2142,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1674,44 +2150,95 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc h.Write(portBuf) portOffset := h.Sum32() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err)) + } + + reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal + if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { + switch netProto { + case header.IPv4ProtocolNumber: + reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + case header.IPv6ProtocolNumber: + reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + } + } + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - // reusePort is false below because connect cannot reuse a port even if - // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { - return false, nil + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + if err != tcpip.ErrPortInUse || !reuse { + return false, nil + } + transEPID := e.ID + transEPID.LocalPort = p + // Check if an endpoint is registered with demuxer in TIME-WAIT and if + // we can reuse it. If we can't find a transport endpoint then we just + // skip using this port as it's possible that either an endpoint has + // bound the port but not registered with demuxer yet (no listen/connect + // done yet) or the reservation was freed between the check above and + // the FindTransportEndpoint below. But rather than retry the same port + // we just skip it and move on. + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + if transEP == nil { + // ReservePort failed but there is no registered endpoint with + // demuxer. Which indicates there is at least some endpoint that has + // bound the port. + return false, nil + } + + tcpEP := transEP.(*endpoint) + tcpEP.LockUser() + // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but + // less than 1 second has elapsed since its recentTS was updated then + // we cannot reuse the port. + if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + tcpEP.UnlockUser() + return false, nil + } + // Since the endpoint is in TIME-WAIT it should be safe to acquire its + // Lock while holding the lock for this endpoint as endpoints in + // TIME-WAIT do not acquire locks on other endpoints. + tcpEP.workerCleanup = false + tcpEP.cleanupLocked() + tcpEP.notifyProtocolGoroutine(notifyAbort) + tcpEP.UnlockUser() + // Now try and Reserve again if it fails then we skip. + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + return false, nil + } } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { - case nil: - e.ID = id - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err == tcpip.ErrPortInUse { + return false, nil + } return false, err } + + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.isPortReserved = true + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.boundDest = addr + return true, nil }); err != nil { return err } } - // Remove the port reservation. This can happen when Bind is called - // before Connect: in such a case we don't want to hold on to - // reservations anymore. - if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) - e.isPortReserved = false - } - e.isRegistered = true - e.state = StateConnecting + e.setEndpointState(StateConnecting) e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1730,14 +2257,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = StateEstablished - e.stack.Stats().TCP.CurrentEstablished.Increment() + e.setEndpointState(StateEstablished) } if run { e.workerRunning = true e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } return tcpip.ErrConnectStarted @@ -1751,14 +2277,17 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - e.shutdownFlags |= flags + e.LockUser() + defer e.UnlockUser() + return e.shutdownLocked(flags) +} +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { + e.shutdownFlags |= flags switch { - case e.state.connected(): + case e.EndpointState().connected(): // Close for read. - if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { + if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. e.rcvListMu.Lock() e.rcvClosed = true @@ -1767,47 +2296,56 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. - if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) + if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + // Wake up worker to terminate loop. + e.notifyProtocolGoroutine(notifyTickleWorker) return nil } } // Close for write. - if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { e.sndBufMu.Lock() - if e.sndClosed { // Already closed. e.sndBufMu.Unlock() - break + if e.EndpointState() == StateTimeWait { + return tcpip.ErrNotConnected + } + return nil } // Queue fin segment. s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - // Mark endpoint as closed. e.sndClosed = true - e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() + e.handleClose() } - case e.state == StateListen: - // Tell protocolListenLoop to stop. - if flags&tcpip.ShutdownRead != 0 { - e.notifyProtocolGoroutine(notifyClose) + return nil + case e.EndpointState() == StateListen: + if e.shutdownFlags&tcpip.ShutdownRead != 0 { + // Reset all connections from the accept queue and keep the + // worker running so that it can continue handling incoming + // segments by replying with RST. + // + // By not removing this endpoint from the demuxer mapping, we + // ensure that any other bind to the same port fails, as on Linux. + e.rcvListMu.Lock() + e.rcvClosed = true + e.rcvListMu.Unlock() + e.closePendingAcceptableConnectionsLocked() + // Notify waiters that the endpoint is shutdown. + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) } - + return nil default: return tcpip.ErrNotConnected } - - return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept @@ -1822,104 +2360,136 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error { } func (e *endpoint) listen(backlog int) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - // Allow the backlog to be adjusted if the endpoint is not shutting down. - // When the endpoint shuts down, it sets workerCleanup to true, and from - // that point onward, acceptedChan is the responsibility of the cleanup() - // method (and should not be touched anywhere else, including here). - if e.state == StateListen && !e.workerCleanup { - // Adjust the size of the channel iff we can fix existing - // pending connections into the new one. - if len(e.acceptedChan) > backlog { - return tcpip.ErrInvalidEndpointState - } - if cap(e.acceptedChan) == backlog { - return nil - } - origChan := e.acceptedChan - e.acceptedChan = make(chan *endpoint, backlog) - close(origChan) - for ep := range origChan { - e.acceptedChan <- ep + e.LockUser() + defer e.UnlockUser() + + if e.EndpointState() == StateListen && !e.closed { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + if e.acceptedChan == nil { + // listen is called after shutdown. + e.acceptedChan = make(chan *endpoint, backlog) + e.shutdownFlags = 0 + e.rcvListMu.Lock() + e.rcvClosed = false + e.rcvListMu.Unlock() + } else { + // Adjust the size of the channel iff we can fix + // existing pending connections into the new one. + if len(e.acceptedChan) > backlog { + return tcpip.ErrInvalidEndpointState + } + if cap(e.acceptedChan) == backlog { + return nil + } + origChan := e.acceptedChan + e.acceptedChan = make(chan *endpoint, backlog) + close(origChan) + for ep := range origChan { + e.acceptedChan <- ep + } } + + // Notify any blocked goroutines that they can attempt to + // deliver endpoints again. + e.acceptCond.Broadcast() + return nil } + if e.EndpointState() == StateInitial { + // The listen is called on an unbound socket, the socket is + // automatically bound to a random free port with the local + // address set to INADDR_ANY. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return err + } + } + // Endpoint must be bound before it can transition to listen mode. - if e.state != StateBound { + if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true - e.state = StateListen + e.setEndpointState(StateListen) + + // The channel may be non-nil when we're restoring the endpoint, and it + // may be pre-populated with some previously accepted (but not Accepted) + // endpoints. + e.acceptMu.Lock() if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } - e.workerRunning = true + e.acceptMu.Unlock() + e.workerRunning = true go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) - return nil } // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { - e.waiterQueue = waiterQueue +func (e *endpoint) startAcceptedLoop() { e.workerRunning = true - go e.protocolMainLoop(false) // S/R-SAFE: drained on save. + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone } // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. - if e.state != StateListen { + if rcvClosed || e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } // Get the new accepted endpoint. + e.acceptMu.Lock() + defer e.acceptMu.Unlock() var n *endpoint select { case n = <-e.acceptedChan: + e.acceptCond.Signal() default: return nil, nil, tcpip.ErrWouldBlock } - - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() + + return e.bindLocked(addr) +} +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrAlreadyBound } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -1935,26 +2505,30 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func(bindToDevice tcpip.NICID) { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }(e.bindToDevice) + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -1968,16 +2542,20 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } + if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + return err + } + // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) return nil } // GetLocalAddress returns the address to which the endpoint is bound. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() return tcpip.FullAddress{ Addr: e.ID.LocalAddress, @@ -1988,10 +2566,10 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress returns the address to which the endpoint is connected. func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() - if !e.state.connected() { + if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -2002,45 +2580,26 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - s.decRef() - return - } - - if !s.csumValid { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() - e.stats.SegmentsReceived.Increment() - if (s.flags & header.TCPFlagRst) != 0 { - e.stack.Stats().TCP.ResetsReceived.Increment() - } +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + // TCP HandlePacket is not required anymore as inbound packets first + // land at the Dispatcher which then can either delivery using the + // worker go routine or directly do the invoke the tcp processing inline + // based on the state of the endpoint. +} +func (e *endpoint) enqueueSegment(s *segment) bool { // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - e.newSegmentWaker.Assert() - } else { + if !e.segmentQueue.enqueue(s) { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.SegmentQueueDropped.Increment() - s.decRef() + return false } + return true } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -2051,6 +2610,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } @@ -2079,20 +2650,16 @@ func (e *endpoint) readyToRead(s *segment) { if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() - // Check if the receive window is now closed. If so make sure - // we set the zero window before we deliver the segment to ensure - // that a subsequent read of the segment will correctly trigger - // a non-zero notification. - if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + // Increase counter if the receive window falls down below MSS + // or half receive buffer size, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - e.zeroWindow = true } e.rcvList.PushBack(s) } else { e.rcvClosed = true } e.rcvListMu.Unlock() - e.waiterQueue.Notify(waiter.EventIn) } @@ -2156,8 +2723,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { - e.recentTS = tsVal + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) } } @@ -2167,22 +2734,21 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { e.sendTSOk = true - e.recentTS = synOpts.TSVal + e.setRecentTimestamp(synOpts.TSVal) } } // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(e.tsOffset) + return tcpTimeStamp(time.Now(), e.tsOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(offset uint32) uint32 { - now := time.Now() - return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { + return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending @@ -2236,9 +2802,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SegTime = time.Now() // Copy EndpointID. - e.mu.Lock() s.ID = stack.TCPEndpointID(e.ID) - e.mu.Unlock() // Copy endpoint rcv state. e.rcvListMu.Lock() @@ -2256,7 +2820,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Endpoint TCP Option state. s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTS + s.RecentTS = e.recentTimestamp() s.TSOffset = e.tsOffset s.SACKPermitted = e.sackPermitted s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) @@ -2327,6 +2891,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState { WEst: cubic.wEst, } } + + rc := e.snd.rc + s.Sender.RACKState = stack.TCPRACKState{ + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + } return s } @@ -2363,17 +2935,15 @@ func (e *endpoint) initGSO() { // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() + e.LockUser() // Make a copy of the endpoint info. ret := e.EndpointInfo - e.mu.RUnlock() + e.UnlockUser() return &ret } @@ -2382,6 +2952,18 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } -func mssForRoute(r *stack.Route) uint16 { - return uint16(r.MTU() - header.TCPMinimumSize) +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.LockUser() + running := e.workerRunning + e.UnlockUser() + if !running { + break + } + <-notifyCh + } } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index eae17237e..723e47ddc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,9 +16,10 @@ package tcp import ( "fmt" - "sync" + "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -48,11 +49,10 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.state { - case StateInitial, StateBound: - // TODO(b/138137272): this enumeration duplicates - // EndpointState.connected. remove it. - case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + epState := e.EndpointState() + switch { + case epState == StateInitial || epState == StateBound: + case epState.connected() || epState.handshake(): if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) @@ -68,35 +68,31 @@ func (e *endpoint) beforeSave() { break } fallthrough - case StateListen, StateConnecting: + case epState == StateListen || epState == StateConnecting: e.drainSegmentLocked() - if e.state != StateClose && e.state != StateError { + // Refresh epState, since drainSegmentLocked may have changed it. + epState = e.EndpointState() + if !epState.closed() { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } - break } - fallthrough - case StateError, StateClose: - for e.state == StateError && e.workerRunning { + case epState.closed(): + for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() } if e.workerRunning { - panic("endpoint still has worker running in closed or error state") + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) } default: - panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) } if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - - if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { - panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") - } } // saveAcceptedChan is invoked by stateify. @@ -135,7 +131,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { // saveState is invoked by stateify. func (e *endpoint) saveState() EndpointState { - return e.state + return e.EndpointState() } // Endpoint loading must be done in the following ordering by their state, to @@ -148,23 +144,34 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state EndpointState) { +func (e *endpoint) loadState(epState EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. - if state.connected() { + // For restore purposes we treat TimeWait like a connected endpoint. + if epState.connected() || epState == StateTimeWait { connectedLoading.Add(1) } - switch state { - case StateListen: + switch { + case epState == StateListen: listenLoading.Add(1) - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): connectingLoading.Add(1) } - e.state = state + // Directly update the state here rather than using e.setEndpointState + // as the endpoint is still being loaded and the stack reference is not + // yet initialized. + atomic.StoreUint32((*uint32)(&e.state), uint32(epState)) } // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial + // Condition variables and mutexs are not S/R'ed so reinitialize + // acceptCond with e.acceptMu. + e.acceptCond = sync.NewCond(&e.acceptMu) stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -172,34 +179,40 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - - state := e.state - switch state { + epState := e.origEndpointState + switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - e.state = StateInitial - if len(e.BindAddr) == 0 { - e.BindAddr = e.ID.LocalAddress + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + if err != nil { + panic("unable to parse BindAddr: " + err.String()) } - if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok { + panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) } + e.isPortReserved = true + + // Mark endpoint as bound. + e.setEndpointState(StateBound) } - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + switch { + case epState.connected(): bind() if len(e.connectingAddress) == 0 { e.connectingAddress = e.ID.RemoteAddress @@ -217,8 +230,18 @@ func (e *endpoint) Resume(s *stack.Stack) { if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if epState == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() - case StateListen: + case epState == StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -227,10 +250,15 @@ func (e *endpoint) Resume(s *stack.Stack) { if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } + e.LockUser() + if e.shutdownFlags != 0 { + e.shutdownLocked(e.shutdownFlags) + } + e.UnlockUser() listenLoading.Done() tcpip.AsyncLoading.Done() }() - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -242,7 +270,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case StateBound: + case epState == StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -251,20 +279,14 @@ func (e *endpoint) Resume(s *stack.Stack) { bind() tcpip.AsyncLoading.Done() }() - case StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.state = StateClose - tcpip.AsyncLoading.Done() - }() - } - fallthrough - case StateError: + case epState == StateClose: + e.isPortReserved = false + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) + case epState == StateError: + e.state = StateError + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } @@ -284,7 +306,17 @@ func (e *endpoint) loadLastError(s string) { return } - e.lastError = loadError(s) + e.lastError = tcpip.StringToError(s) +} + +// saveRecentTSTime is invoked by stateify. +func (e *endpoint) saveRecentTSTime() unixTime { + return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} +} + +// loadRecentTSTime is invoked by stateify. +func (e *endpoint) loadRecentTSTime(unix unixTime) { + e.recentTSTime = time.Unix(unix.second, unix.nano) } // saveHardError is invoked by stateify. @@ -302,71 +334,7 @@ func (e *EndpointInfo) loadHardError(s string) { return } - e.HardError = loadError(s) -} - -var messageToError map[string]*tcpip.Error - -var populate sync.Once - -func loadError(s string) *tcpip.Error { - populate.Do(func() { - var errors = []*tcpip.Error{ - tcpip.ErrUnknownProtocol, - tcpip.ErrUnknownNICID, - tcpip.ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID, - tcpip.ErrDuplicateAddress, - tcpip.ErrNoRoute, - tcpip.ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound, - tcpip.ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected, - tcpip.ErrNoPortAvailable, - tcpip.ErrPortInUse, - tcpip.ErrBadLocalAddress, - tcpip.ErrClosedForSend, - tcpip.ErrClosedForReceive, - tcpip.ErrWouldBlock, - tcpip.ErrConnectionRefused, - tcpip.ErrTimeout, - tcpip.ErrAborted, - tcpip.ErrConnectStarted, - tcpip.ErrDestinationRequired, - tcpip.ErrNotSupported, - tcpip.ErrQueueSizeNotSupported, - tcpip.ErrNotConnected, - tcpip.ErrConnectionReset, - tcpip.ErrConnectionAborted, - tcpip.ErrNoSuchFile, - tcpip.ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress, - tcpip.ErrBadAddress, - tcpip.ErrNetworkUnreachable, - tcpip.ErrMessageTooLong, - tcpip.ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled, - tcpip.ErrNotPermitted, - tcpip.ErrAddressFamilyNotSupported, - } - - messageToError = make(map[string]*tcpip.Error) - for _, e := range errors { - if messageToError[e.String()] != nil { - panic("tcpip errors with duplicated message: " + e.String()) - } - messageToError[e.String()] = e - } - }) - - e, ok := messageToError[s] - if !ok { - panic("unknown error message: " + s) - } - - return e + e.HardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..070b634b4 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -15,10 +15,8 @@ package tcp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +61,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. @@ -132,7 +130,7 @@ func (r *ForwarderRequest) Complete(sendReset bool) { // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment) + replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) } // Release all resources. @@ -159,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }) + }, queue, nil) if err != nil { return nil, err } // Start the protocol goroutine. - ep.startAcceptedLoop(queue) + ep.startAcceptedLoop() return ep, nil } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index db40785d3..c5afa2680 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,9 +21,11 @@ package tcp import ( + "runtime" "strings" - "sync" + "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -54,41 +56,149 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // MaxTCPLingerTimeout is the maximum amount of time that sockets + // linger in FIN_WAIT_2 state before being marked closed. + MaxTCPLingerTimeout = 120 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second + + // DefaultSynRetries is the default value for the number of SYN retransmits + // before a connect is aborted. + DefaultSynRetries = 6 +) + +const ( + ccReno = "reno" + ccCubic = "cubic" ) -// SACKEnabled option can be used to enable SACK support in the TCP -// protocol. See: https://tools.ietf.org/html/rfc2018. +// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool -// SendBufferSizeOption allows the default, min and max send buffer sizes for -// TCP endpoints to be queried or configured. +// Recovery is used by stack.(*Stack).TransportProtocolOption to +// set loss detection algorithm in TCP. +type Recovery int32 + +const ( + // RACKLossDetection indicates RACK is used for loss detection and + // recovery. + RACKLossDetection Recovery = 1 << iota + + // RACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + RACKStaticReoWnd + + // RACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + RACKNoDupTh +) + +// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to +// enable/disable Nagle's algorithm in TCP. +type DelayEnabled bool + +// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption +// to get/set the default, min and max TCP send buffer sizes. type SendBufferSizeOption struct { Min int Default int Max int } -// ReceiveBufferSizeOption allows the default, min and max receive buffer size -// for TCP endpoints to be queried or configured. +// ReceiveBufferSizeOption is used by +// stack.(Stack*).TransportProtocolOption to get/set the default, min and max +// TCP receive buffer sizes. type ReceiveBufferSizeOption struct { Min int Default int Max int } -const ( - ccReno = "reno" - ccCubic = "cubic" -) +// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The +// value is protected by a mutex so that we can increment only when it's +// guaranteed not to go above a threshold. +type synRcvdCounter struct { + sync.Mutex + value uint64 + pending sync.WaitGroup + threshold uint64 +} + +// inc tries to increment the global number of endpoints in SYN-RCVD state. It +// succeeds if the increment doesn't make the count go beyond the threshold, and +// fails otherwise. +func (s *synRcvdCounter) inc() bool { + s.Lock() + defer s.Unlock() + if s.value >= s.threshold { + return false + } + + s.pending.Add(1) + s.value++ + + return true +} + +// dec atomically decrements the global number of endpoints in SYN-RCVD +// state. It must only be called if a previous call to inc succeeded. +func (s *synRcvdCounter) dec() { + s.Lock() + defer s.Unlock() + s.value-- + s.pending.Done() +} + +// synCookiesInUse returns true if the synRcvdCount is greater than +// SynRcvdCountThreshold. +func (s *synRcvdCounter) synCookiesInUse() bool { + s.Lock() + defer s.Unlock() + return s.value >= s.threshold +} + +// SetThreshold sets synRcvdCounter.Threshold to ths new threshold. +func (s *synRcvdCounter) SetThreshold(threshold uint64) { + s.Lock() + defer s.Unlock() + s.threshold = threshold +} + +// Threshold returns the current value of synRcvdCounter.Threhsold. +func (s *synRcvdCounter) Threshold() uint64 { + s.Lock() + defer s.Unlock() + return s.threshold +} type protocol struct { - mu sync.Mutex + mu sync.RWMutex sackEnabled bool + recovery Recovery + delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + lingerTimeout time.Duration + timeWaitTimeout time.Duration + timeWaitReuse tcpip.TCPTimeWaitReuseOption + minRTO time.Duration + maxRTO time.Duration + maxRetries uint32 + synRcvdCount synRcvdCounter + synRetries uint8 + dispatcher dispatcher } // Number returns the tcp protocol number. @@ -97,7 +207,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, waiterQueue), nil } @@ -119,6 +229,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } +// QueuePacket queues packets targeted at an endpoint after hashing the packet +// to a specific processing queue. Each queue is serviced by its own processor +// goroutine which is responsible for dequeuing and doing full TCP dispatch of +// the packet. +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + p.dispatcher.queuePacket(r, ep, id, pkt) +} + // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. // @@ -126,8 +244,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { @@ -139,24 +257,45 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo return true } - replyWithReset(s) + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) return true } // replyWithReset replies to the given segment with a reset segment. -func replyWithReset(s *segment) { +func replyWithReset(s *segment, tos, ttl uint8) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, tcpFields{ + id: s.id, + ttl: ttl, + tos: tos, + flags: flags, + seq: seq, + ack: ack, + rcvWnd: 0, + }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -165,6 +304,18 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case Recovery: + p.mu.Lock() + p.recovery = Recovery(v) + p.mu.Unlock() + return nil + + case DelayEnabled: + p.mu.Lock() + p.delayEnabled = bool(v) + p.mu.Unlock() + return nil + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue @@ -202,48 +353,174 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.lingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.timeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitReuseOption: + if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.timeWaitReuse = v + p.mu.Unlock() + return nil + + case tcpip.TCPMinRTOOption: + if v < 0 { + v = tcpip.TCPMinRTOOption(MinRTO) + } + p.mu.Lock() + p.minRTO = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPMaxRTOOption: + if v < 0 { + v = tcpip.TCPMaxRTOOption(MaxRTO) + } + p.mu.Lock() + p.maxRTO = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPMaxRetriesOption: + p.mu.Lock() + p.maxRetries = uint32(v) + p.mu.Unlock() + return nil + + case tcpip.TCPSynRcvdCountThresholdOption: + p.mu.Lock() + p.synRcvdCount.SetThreshold(uint64(v)) + p.mu.Unlock() + return nil + + case tcpip.TCPSynRetriesOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.synRetries = uint8(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: - p.mu.Lock() + p.mu.RLock() *v = SACKEnabled(p.sackEnabled) - p.mu.Unlock() + p.mu.RUnlock() + return nil + + case *Recovery: + p.mu.RLock() + *v = Recovery(p.recovery) + p.mu.RUnlock() + return nil + + case *DelayEnabled: + p.mu.RLock() + *v = DelayEnabled(p.delayEnabled) + p.mu.RUnlock() return nil case *SendBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.sendBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *ReceiveBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.recvBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.CongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.CongestionControlOption(p.congestionControl) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.AvailableCongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.ModerateReceiveBufferOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) - p.mu.Unlock() + p.mu.RUnlock() + return nil + + case *tcpip.TCPLingerTimeoutOption: + p.mu.RLock() + *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout) + p.mu.RUnlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.RLock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout) + p.mu.RUnlock() + return nil + + case *tcpip.TCPTimeWaitReuseOption: + p.mu.RLock() + *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) + p.mu.RUnlock() + return nil + + case *tcpip.TCPMinRTOOption: + p.mu.RLock() + *v = tcpip.TCPMinRTOOption(p.minRTO) + p.mu.RUnlock() + return nil + + case *tcpip.TCPMaxRTOOption: + p.mu.RLock() + *v = tcpip.TCPMaxRTOOption(p.maxRTO) + p.mu.RUnlock() + return nil + + case *tcpip.TCPMaxRetriesOption: + p.mu.RLock() + *v = tcpip.TCPMaxRetriesOption(p.maxRetries) + p.mu.RUnlock() + return nil + + case *tcpip.TCPSynRcvdCountThresholdOption: + p.mu.RLock() + *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + p.mu.RUnlock() + return nil + + case *tcpip.TCPSynRetriesOption: + p.mu.RLock() + *v = tcpip.TCPSynRetriesOption(p.synRetries) + p.mu.RUnlock() return nil default: @@ -251,12 +528,67 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + +// SynRcvdCounter returns a reference to the synRcvdCount for this protocol +// instance. +func (p *protocol) SynRcvdCounter() *synRcvdCounter { + return &p.synRcvdCount +} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset + } + + _, ok = pkt.TransportHeader().Consume(hdrLen) + return ok +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, + p := protocol{ + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultSendBufferSize, + Max: MaxBufferSize, + }, + recvBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultReceiveBufferSize, + Max: MaxBufferSize, + }, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + lingerTimeout: DefaultTCPLingerTimeout, + timeWaitTimeout: DefaultTCPTimeWaitTimeout, + timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, + synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, + synRetries: DefaultSynRetries, + minRTO: MinRTO, + maxRTO: MaxRTO, + maxRetries: MaxRetries, + recovery: RACKLossDetection, } + p.dispatcher.init(runtime.GOMAXPROCS(0)) + return &p } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go new file mode 100644 index 000000000..d969ca23a --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack.go @@ -0,0 +1,82 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// RACK is a loss detection algorithm used in TCP to detect packet loss and +// reordering using transmission timestamp of the packets instead of packet or +// sequence counts. To use RACK, SACK should be enabled on the connection. + +// rackControl stores the rack related fields. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1 +// +// +stateify savable +type rackControl struct { + // xmitTime is the latest transmission timestamp of rackControl.seg. + xmitTime time.Time `state:".(unixTime)"` + + // endSequence is the ending TCP sequence number of rackControl.seg. + endSequence seqnum.Value + + // fack is the highest selectively or cumulatively acknowledged + // sequence. + fack seqnum.Value + + // rtt is the RTT of the most recently delivered packet on the + // connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + rtt time.Duration +} + +// Update will update the RACK related fields when an ACK has been received. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) { + rtt := time.Now().Sub(seg.xmitTime) + + // If the ACK is for a retransmitted packet, do not update if it is a + // spurious inference which is determined by below checks: + // 1. When Timestamping option is available, if the TSVal is less than the + // transmit time of the most recent retransmitted packet. + // 2. When RTT calculated for the packet is less than the smoothed RTT + // for the connection. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // step 2 + if seg.xmitCount > 1 { + if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + return + } + } + if rtt < srtt { + return + } + } + + rc.rtt = rtt + // Update rc.xmitTime and rc.endSequence to the transmit time and + // ending sequence number of the packet which has been acknowledged + // most recently. + endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { + rc.xmitTime = seg.xmitTime + rc.endSequence = endSeq + } +} diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go new file mode 100644 index 000000000..c9dc7e773 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack_state.go @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveXmitTime is invoked by stateify. +func (rc *rackControl) saveXmitTime() unixTime { + return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (rc *rackControl) loadXmitTime(unix unixTime) { + rc.xmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..5e0bfe585 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -49,29 +50,36 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } // acceptable checks if the segment sequence number range is acceptable // according to the table on page 26 of RFC 793. func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - rcvWnd := r.rcvNxt.Size(r.rcvAcc) - if rcvWnd == 0 { - return segLen == 0 && segSeq == r.rcvNxt + // r.rcvWnd could be much larger than the window size we advertised in our + // outgoing packets, we should use what we have advertised for acceptability + // test. + scaledWindowSize := r.rcvWnd >> r.rcvWndScale + if scaledWindowSize > 0xffff { + // This is what we actually put in the Window field. + scaledWindowSize = 0xffff } - - return segSeq.InWindow(r.rcvNxt, rcvWnd) || - seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen) + advertisedWindowSize := scaledWindowSize << r.rcvWndScale + return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) } // getSendParams returns the parameters needed by the sender when building @@ -93,12 +101,6 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // in such cases we may need to send an ack to indicate to our peer that it can // resume sending data. func (r *receiver) nonZeroWindow() { - if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 { - // We never got around to announcing a zero window size, so we - // don't need to immediately announce a nonzero one. - return - } - // Immediately send an ack. r.ep.snd.sendAck() } @@ -169,22 +171,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. - r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateEstablished: - r.ep.state = StateCloseWait + r.ep.setEndpointState(StateCloseWait) case StateFinWait1: if s.flagIsSet(header.TCPFlagAck) { // FIN-ACK, transition to TIME-WAIT. - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } else { // Simultaneous close, expecting a final ACK. - r.ep.state = StateClosing + r.ep.setEndpointState(StateClosing) } case StateFinWait2: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } - r.ep.mu.Unlock() // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the @@ -196,6 +196,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum for i := first; i < len(r.pendingRcvdSegments); i++ { r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of + // truncated items, thus truncated items must be set to nil to avoid + // memory leaks. + r.pendingRcvdSegments[i] = nil } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] @@ -204,17 +208,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) { - r.ep.mu.Lock() - switch r.ep.state { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + switch r.ep.EndpointState() { case StateFinWait1: - r.ep.state = StateFinWait2 + r.ep.setEndpointState(StateFinWait2) + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) case StateLastAck: - r.ep.state = StateClose + r.ep.transitionToStateCloseLocked() } - r.ep.mu.Unlock() } return true @@ -253,32 +260,119 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait: + // If the ACK acks something not yet sent then we send an ACK. + if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + r.ep.snd.sendAck() + return true, nil + } + fallthrough + case StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + state := r.ep.EndpointState() + closed := r.ep.closed segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil + } + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } + // Store the time of the last ack. + r.lastRcvdAckTime = time.Now() + // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { // We only store the segment if it's within our buffer // size limit. if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += s.logicalLen() + r.pendingBufUsed += seqnum.Size(s.segMemSize()) s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -288,7 +382,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -312,7 +406,70 @@ func (r *receiver) handleRcvdSegment(s *segment) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= s.logicalLen() + r.pendingBufUsed -= seqnum.Size(s.segMemSize()) s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/transport/tcp/rcv_state.go index 19a7f77e3..2bf21a2e7 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -12,24 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This file contains various Targets. +package tcp -package iptables +import ( + "time" +) -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// UnconditionalAcceptTarget accepts all packets. -type UnconditionalAcceptTarget struct{} - -// Action implements Target.Action. -func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, string) { - return Accept, "" +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} } -// UnconditionalDropTarget denies all packets. -type UnconditionalDropTarget struct{} - -// Action implements Target.Action. -func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) { - return Drop, "" +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) } diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go new file mode 100644 index 000000000..8a026ec46 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rcv_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +func TestAcceptable(t *testing.T) { + for _, tt := range []struct { + segSeq seqnum.Value + segLen seqnum.Size + rcvNxt, rcvAcc seqnum.Value + want bool + }{ + // The segment is smaller than the window. + {105, 2, 100, 104, false}, + {105, 2, 101, 105, true}, + {105, 2, 102, 106, true}, + {105, 2, 103, 107, true}, + {105, 2, 104, 108, true}, + {105, 2, 105, 109, true}, + {105, 2, 106, 110, true}, + {105, 2, 107, 111, false}, + + // The segment is larger than the window. + {105, 4, 103, 105, true}, + {105, 4, 104, 106, true}, + {105, 4, 105, 107, true}, + {105, 4, 106, 108, true}, + {105, 4, 107, 109, true}, + {105, 4, 108, 110, true}, + {105, 4, 109, 111, false}, + {105, 4, 110, 112, false}, + + // The segment has no width. + {105, 0, 100, 102, false}, + {105, 0, 101, 103, false}, + {105, 0, 102, 104, false}, + {105, 0, 103, 105, true}, + {105, 0, 104, 106, true}, + {105, 0, 105, 107, true}, + {105, 0, 106, 108, false}, + {105, 0, 107, 109, false}, + + // The receive window has no width. + {105, 2, 103, 103, false}, + {105, 2, 104, 104, false}, + {105, 2, 105, 105, false}, + {105, 2, 106, 106, false}, + {105, 2, 107, 107, false}, + {105, 2, 108, 108, false}, + {105, 2, 109, 109, false}, + } { + if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { + t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) + } + } +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index ea725d513..94307d31a 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -35,6 +35,7 @@ type segment struct { id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -55,18 +56,19 @@ type segment struct { options []byte `state:".([]byte)"` hasNewSACKInfo bool rcvdTime time.Time `state:".(unixTime)"` - // xmitTime is the last transmit time of this segment. A zero value - // indicates that the segment has yet to be transmitted. - xmitTime time.Time `state:".(unixTime)"` + // xmitTime is the last transmit time of this segment. + xmitTime time.Time `state:".(unixTime)"` + xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) + s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() return s } @@ -77,9 +79,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V id: id, route: r.Clone(), } - s.views[0] = v - s.data = buffer.NewVectorisedView(len(v), s.views[:1]) s.rcvdTime = time.Now() + if len(v) != 0 { + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + } return s } @@ -94,13 +98,21 @@ func (s *segment) clone() *segment { route: s.route.Clone(), viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, + xmitTime: s.xmitTime, + xmitCount: s.xmitCount, } t.data = s.data.Clone(t.views[:]) return t } -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags } func (s *segment) decRef() { @@ -126,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// segMemSize is the amount of memory used to hold the segment data and +// the associated metadata. +func (s *segment) segMemSize() int { + return segSize + s.data.Size() +} + // parse populates the sequence & ack numbers, flags, and window fields of the // segment from the TCP header stored in the data. It then updates the view to // skip the header. @@ -136,8 +154,6 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) - // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -148,12 +164,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(s.hdr.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(s.hdr) { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -162,21 +178,19 @@ func (s *segment) parse() bool { if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { s.csumValid = true verifyChecksum = false - s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) - s.data.TrimFront(offset) + s.csum = s.hdr.Checksum() + xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(s.hdr.AckNumber()) + s.flags = s.hdr.Flags() + s.window = seqnum.Size(s.hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go index 9fd061d7d..8d3ddce4b 100644 --- a/pkg/tcpip/transport/tcp/segment_heap.go +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -14,21 +14,25 @@ package tcp +import "container/heap" + type segmentHeap []*segment +var _ heap.Interface = (*segmentHeap)(nil) + // Len returns the length of h. -func (h segmentHeap) Len() int { - return len(h) +func (h *segmentHeap) Len() int { + return len(*h) } // Less determines whether the i-th element of h is less than the j-th element. -func (h segmentHeap) Less(i, j int) bool { - return h[i].sequenceNumber.LessThan(h[j].sequenceNumber) +func (h *segmentHeap) Less(i, j int) bool { + return (*h)[i].sequenceNumber.LessThan((*h)[j].sequenceNumber) } // Swap swaps the i-th and j-th elements of h. -func (h segmentHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] +func (h *segmentHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } // Push adds x as the last element of h. @@ -41,6 +45,7 @@ func (h *segmentHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] + old[n-1] = nil *h = old[:n-1] return x } diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index e0759225e..48a257137 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -15,7 +15,7 @@ package tcp import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // segmentQueue is a bounded, thread-safe queue of TCP segments. @@ -28,10 +28,16 @@ type segmentQueue struct { used int } +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *segmentQueue) emptyLocked() bool { + return q.used == 0 +} + // empty determines if the queue is empty. func (q *segmentQueue) empty() bool { q.mu.Lock() - r := q.used == 0 + r := q.emptyLocked() q.mu.Unlock() return r diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go new file mode 100644 index 000000000..0ab7b8f56 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "unsafe" +) + +const ( + segSize = int(unsafe.Sizeof(segment{})) +) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index d3f7c9125..c55589c45 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -15,12 +15,13 @@ package tcp import ( + "fmt" "math" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -28,8 +29,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. - minRTO = 200 * time.Millisecond + // MinRTO is the minimum allowed value for the retransmit timeout. + MinRTO = 200 * time.Millisecond + + // MaxRTO is the maximum allowed value for the retransmit timeout. + MaxRTO = 120 * time.Second // InitialCwnd is the initial congestion window. InitialCwnd = 10 @@ -37,6 +41,11 @@ const ( // nDupAckThreshold is the number of duplicate ACK's required // before fast-retransmit is entered. nDupAckThreshold = 3 + + // MaxRetries is the maximum number of probe retries sender does + // before timing out the connection. + // Linux default TCP_RETR2, net.ipv4.tcp_retries2. + MaxRetries = 15 ) // ccState indicates the current congestion control state for this sender. @@ -123,10 +132,6 @@ type sender struct { // sndNxt is the sequence number of the next segment to be sent. sndNxt seqnum.Value - // sndNxtList is the sequence number of the next segment to be added to - // the send list. - sndNxtList seqnum.Value - // rttMeasureSeqNum is the sequence number being used for the latest RTT // measurement. rttMeasureSeqNum seqnum.Value @@ -134,6 +139,18 @@ type sender struct { // rttMeasureTime is the time when the rttMeasureSeqNum was sent. rttMeasureTime time.Time `state:".(unixTime)"` + // firstRetransmittedSegXmitTime is the original transmit time of + // the first segment that was retransmitted due to RTO expiration. + firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + + // zeroWindowProbing is set if the sender is currently probing + // for zero receive window. + zeroWindowProbing bool `state:"nosave"` + + // unackZeroWindowProbes is the number of unacknowledged zero + // window probes. + unackZeroWindowProbes uint32 `state:"nosave"` + closed bool writeNext *segment writeList segmentList @@ -146,6 +163,15 @@ type sender struct { rtt rtt rto time.Duration + // minRTO is the minimum permitted value for sender.rto. + minRTO time.Duration + + // maxRTO is the maximum permitted value for sender.rto. + maxRTO time.Duration + + // maxRetries is the maximum permitted retransmissions. + maxRetries uint32 + // maxPayloadSize is the maximum size of the payload of a given segment. // It is initialized on demand. maxPayloadSize int @@ -165,6 +191,10 @@ type sender struct { // cc is the congestion control algorithm in use for this sender. cc congestionControl + + // rc has the fields needed for implementing RACK loss detection + // algorithm. + rc rackControl } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -222,7 +252,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint sndWnd: sndWnd, sndUna: iss + 1, sndNxt: iss + 1, - sndNxtList: iss + 1, rto: 1 * time.Second, rttMeasureSeqNum: iss + 1, lastSendTime: time.Now(), @@ -258,6 +287,25 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // etc. s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + // Get Stack wide config. + var minRTO tcpip.TCPMinRTOOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &minRTO); err != nil { + panic(fmt.Sprintf("unable to get minRTO from stack: %s", err)) + } + s.minRTO = time.Duration(minRTO) + + var maxRTO tcpip.TCPMaxRTOOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRTO); err != nil { + panic(fmt.Sprintf("unable to get maxRTO from stack: %s", err)) + } + s.maxRTO = time.Duration(maxRTO) + + var maxRetries tcpip.TCPMaxRetriesOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRetries); err != nil { + panic(fmt.Sprintf("unable to get maxRetries from stack: %s", err)) + } + s.maxRetries = uint32(maxRetries) + return s } @@ -392,8 +440,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < minRTO { - s.rto = minRTO + if s.rto < s.minRTO { + s.rto = s.minRTO } } @@ -435,17 +483,56 @@ func (s *sender) retransmitTimerExpired() bool { return true } + // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases + // when writeList is empty. Remove this once we have a proper fix for this + // issue. + if s.writeList.Front() == nil { + return true + } + s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() - // Give up if we've waited more than a minute since the last resend. - if s.rto >= 60*time.Second { + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + uto := s.ep.userTimeout + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. + s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } + + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := s.maxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + // Always honor the user-timeout irrespective of whether the zero + // window probes were acknowledged. + // net/ipv4/tcp_timer.c::tcp_probe_timer() + if remaining <= 0 || s.unackZeroWindowProbes >= s.maxRetries { return false } // Set new timeout. The timer will be restarted by the call to sendData // below. s.rto *= 2 + // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5 + if s.rto > s.maxRTO { + s.rto = s.maxRTO + } + + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. // @@ -488,6 +575,26 @@ func (s *sender) retransmitTimerExpired() bool { // information is usable after an RTO. s.ep.scoreboard.Reset() s.writeNext = s.writeList.Front() + + // RFC 1122 4.2.2.17: Start sending zero window probes when we still see a + // zero receive window after retransmission interval and we have data to + // send. + if s.zeroWindowProbing { + s.sendZeroWindowProbe() + // RFC 1122 4.2.2.17: A TCP MAY keep its offered receive window closed + // indefinitely. As long as the receiving TCP continues to send + // acknowledgments in response to the probe segments, the sending TCP + // MUST allow the connection to stay open. + return true + } + + seg := s.writeNext + // RFC 1122 4.2.3.5: Close the connection when the number of + // retransmissions for this segment is beyond a limit. + if seg != nil && seg.xmitCount > s.maxRetries { + return false + } + s.sendData() return true @@ -515,25 +622,51 @@ func (s *sender) splitSeg(seg *segment, size int) { nSeg.data.TrimFront(size) nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) s.writeList.InsertAfter(seg, nSeg) + + // The segment being split does not carry PUSH flag because it is + // followed by the newly split segment. + // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered + // segment (i.e., when there is no more queued data to be sent). + // Linux removes PSH flag only when the segment is being split over MSS + // and retains it when we are splitting the segment over lack of sender + // window space. + // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() + if seg.data.Size() > s.maxPayloadSize { + seg.flags ^= header.TCPFlagPsh + } + seg.data.CapLength(size) } -// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that -// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2 -// is handled by the normal send logic. -func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { +// NextSeg implements the RFC6675 NextSeg() operation. +// +// NextSeg starts scanning the writeList starting from nextSegHint and returns +// the hint to be passed on the next call to NextSeg. This is required to avoid +// iterating the write list repeatedly when NextSeg is invoked in a loop during +// recovery. The returned hint will be nil if there are no more segments that +// can match rules defined by NextSeg operation in RFC6675. +// +// rescueRtx will be true only if nextSeg is a rescue retransmission as +// described by Step 4) of the NextSeg algorithm. +func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRtx bool) { var s3 *segment var s4 *segment - smss := s.ep.scoreboard.SMSS() // Step 1. - for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { - if !s.isAssignedSequenceNumber(seg) { + for seg := nextSegHint; seg != nil; seg = seg.Next() { + // Stop iteration if we hit a segment that has never been + // transmitted (i.e. either it has no assigned sequence number + // or if it does have one, it's >= the next sequence number + // to be sent [i.e. >= s.sndNxt]). + if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) { + hint = nil break } segSeq := seg.sequenceNumber - if seg.data.Size() > int(smss) { + if smss := s.ep.scoreboard.SMSS(); seg.data.Size() > int(smss) { s.splitSeg(seg, int(smss)) } + // See RFC 6675 Section 4 // // 1. If there exists a smallest unSACKED sequence number @@ -550,8 +683,9 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { // NextSeg(): // (1.c) IsLost(S2) returns true. if s.ep.scoreboard.IsLost(segSeq) { - return seg, s3, s4 + return seg, seg.Next(), false } + // NextSeg(): // // (3): If the conditions for rules (1) and (2) @@ -563,6 +697,7 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { // SHOULD be returned. if s3 == nil { s3 = seg + hint = seg.Next() } } // NextSeg(): @@ -571,10 +706,12 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { // but there exists outstanding unSACKED data, we // provide the opportunity for a single "rescue" // retransmission per entry into loss recovery. If - // HighACK is greater than RescueRxt, the one - // segment of upto SMSS octects that MUST include - // the highest outstanding unSACKed sequence number - // SHOULD be returned. + // HighACK is greater than RescueRxt (or RescueRxt + // is undefined), then one segment of upto SMSS + // octects that MUST include the highest outstanding + // unSACKed sequence number SHOULD be returned, and + // RescueRxt set to RecoveryPoint. HighRxt MUST NOT + // be updated. if s.fr.rescueRxt.LessThan(s.sndUna - 1) { if s4 != nil { if s4.sequenceNumber.LessThan(segSeq) { @@ -583,12 +720,31 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { } else { s4 = seg } - s.fr.rescueRxt = s.fr.last } } } - return nil, s3, s4 + // If we got here then no segment matched step (1). + // Step (2): "If no sequence number 'S2' per rule (1) + // exists but there exists available unsent data and the + // receiver's advertised window allows, the sequence + // range of one segment of up to SMSS octets of + // previously unsent data starting with sequence number + // HighData+1 MUST be returned." + for seg := s.writeNext; seg != nil; seg = seg.Next() { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + continue + } + // We do not split the segment here to <= smss as it has + // potentially not been assigned a sequence number yet. + return seg, nil, false + } + + if s3 != nil { + return s3, hint, false + } + + return s4, nil, true } // maybeSendSegment tries to send the specified segment and either coalesces @@ -601,7 +757,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(seg.sequenceNumber.Size(end)) + available := int(s.sndNxt.Size(end)) if available > limit { available = limit } @@ -644,8 +800,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // sent all at once. return false } - if atomic.LoadUint32(&s.ep.cork) != 0 { - // Hold back the segment until full. + // With TCP_CORK, hold back until minimum of the available + // send space and MSS. + // TODO(gvisor.dev/issue/2833): Drain the held segments after a + // timeout. + if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { return false } } @@ -664,18 +823,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) - // Transition to FIN-WAIT1 state since we're initiating an active close. - s.ep.mu.Lock() - switch s.ep.state { + // Update the state to reflect that we have now + // queued a FIN. + switch s.ep.EndpointState() { case StateCloseWait: - // We've already received a FIN and are now sending our own. The - // sender is now awaiting a final ACK for this FIN. - s.ep.state = StateLastAck + s.ep.setEndpointState(StateLastAck) default: - s.ep.state = StateFinWait1 + s.ep.setEndpointState(StateFinWait1) } - s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() - s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { @@ -690,10 +845,52 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if available == 0 { return false } + + // If the whole segment or at least 1MSS sized segment cannot + // be accomodated in the receiver advertized window, skip + // splitting and sending of the segment. ref: + // net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered by + // a probe timer. On this condition, it defers the segment split + // and transmit to a short probe timer. + // + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split + // the segment right here if there are no pending segments. If + // there are pending segments, segment transmits are deferred to + // the retransmit timer handler. + if s.sndUna != s.sndNxt { + switch { + case available >= seg.data.Size(): + // OK to send, the whole segments fits in the + // receiver's advertised window. + case available >= s.maxPayloadSize: + // OK to send, at least 1 MSS sized segment fits + // in the receiver's advertised window. + default: + return false + } + } + + // The segment size limit is computed as a function of sender + // congestion window and MSS. When sender congestion window is > + // 1, this limit can be larger than MSS. Ensure that the + // currently available send space is not greater than minimum of + // this limit and MSS. if available > limit { available = limit } + // If GSO is not in use then cap available to + // maxPayloadSize. When GSO is in use the gVisor GSO logic or + // the host GSO logic will cap the segment to the correct size. + if s.ep.gso == nil && available > s.maxPayloadSize { + available = s.maxPayloadSize + } + if seg.data.Size() > available { s.splitSeg(seg, available) } @@ -716,64 +913,47 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // section 5, step C. func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { s.SetPipe() + + if smss := int(s.ep.scoreboard.SMSS()); limit > smss { + // Cap segment size limit to s.smss as SACK recovery requires + // that all retransmissions or new segments send during recovery + // be of <= SMSS. + limit = smss + } + + nextSegHint := s.writeList.Front() for s.outstanding < s.sndCwnd { - nextSeg, s3, s4 := s.NextSeg() + var nextSeg *segment + var rescueRtx bool + nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint) if nextSeg == nil { - // NextSeg(): - // - // Step (2): "If no sequence number 'S2' per rule (1) - // exists but there exists available unsent data and the - // receiver's advertised window allows, the sequence - // range of one segment of up to SMSS octets of - // previously unsent data starting with sequence number - // HighData+1 MUST be returned." - for seg := s.writeNext; seg != nil; seg = seg.Next() { - if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { - continue - } - // Step C.3 described below is handled by - // maybeSendSegment which increments sndNxt when - // a segment is transmitted. - // - // Step C.3 "If any of the data octets sent in - // (C.1) are above HighData, HighData must be - // updated to reflect the transmission of - // previously unsent data." - if sent := s.maybeSendSegment(seg, limit, end); !sent { - break - } - dataSent = true - s.outstanding++ - s.writeNext = seg.Next() - nextSeg = seg - break - } - if nextSeg != nil { - continue - } - } - rescueRtx := false - if nextSeg == nil && s3 != nil { - nextSeg = s3 - } - if nextSeg == nil && s4 != nil { - nextSeg = s4 - rescueRtx = true + return dataSent } - if nextSeg == nil { - break - } - segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) - if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) { - // RFC 6675, Step C.2 + if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + // New data being sent. + + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. // - // "If any of the data octets sent in (C.1) are below - // HighData, HighRxt MUST be set to the highest sequence - // number of the retransmitted segment unless NextSeg () - // rule (4) was invoked for this retransmission." - s.fr.highRxt = segEnd - 1 + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + // + // We pass s.smss as the limit as the Step 2) requires that + // new data sent should be of size s.smss or less. + if sent := s.maybeSendSegment(nextSeg, limit, end); !sent { + return dataSent + } + dataSent = true + s.outstanding++ + s.writeNext = nextSeg.Next() + continue } + // Now handle the retransmission case where we matched either step 1,3 or 4 + // of the NextSeg algorithm. // RFC 6675, Step C.4. // // "The estimate of the amount of data outstanding in the network @@ -782,10 +962,54 @@ func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) s.outstanding++ dataSent = true s.sendSegment(nextSeg) + + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if rescueRtx { + // We do the last part of rule (4) of NextSeg here to update + // RescueRxt as until this point we don't know if we are going + // to use the rescue transmission. + s.fr.rescueRxt = s.fr.last + } else { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + s.fr.highRxt = segEnd - 1 + } } return dataSent } +func (s *sender) sendZeroWindowProbe() { + ack, win := s.ep.rcv.getSendParams() + s.unackZeroWindowProbes++ + // Send a zero window probe with sequence number pointing to + // the last acknowledged byte. + s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win) + // Rearm the timer to continue probing. + s.resendTimer.enable(s.rto) +} + +func (s *sender) enableZeroWindowProbing() { + s.zeroWindowProbing = true + // We piggyback the probing on the retransmit timer with the + // current retranmission interval, as we may start probing while + // segment retransmissions. + if s.firstRetransmittedSegXmitTime.IsZero() { + s.firstRetransmittedSegXmitTime = time.Now() + } + s.resendTimer.enable(s.rto) +} + +func (s *sender) disableZeroWindowProbing() { + s.zeroWindowProbing = false + s.unackZeroWindowProbes = 0 + s.firstRetransmittedSegXmitTime = time.Time{} + s.resendTimer.disable() +} + // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { @@ -799,7 +1023,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto { + if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { if s.sndCwnd > InitialCwnd { s.sndCwnd = InitialCwnd } @@ -817,6 +1041,9 @@ func (s *sender) sendData() { limit = cwndLimit } if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Move writeNext along so that we don't try and scan data that + // has already been SACKED. + s.writeNext = seg.Next() continue } if sent := s.maybeSendSegment(seg, limit, end); !sent { @@ -834,6 +1061,13 @@ func (s *sender) sendData() { s.ep.disableKeepaliveTimer() } + // If the sender has advertized zero receive window and we have + // data to be sent out, start zero window probing to query the + // the remote for it's receive window size. + if s.writeNext != nil && s.sndWnd == 0 { + s.enableZeroWindowProbing() + } + // Enable the timer if we have pending data and it's not enabled yet. if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { s.resendTimer.enable(s.rto) @@ -855,6 +1089,8 @@ func (s *sender) enterFastRecovery() { s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding + s.fr.highRxt = s.sndUna + s.fr.rescueRxt = s.sndUna if s.ep.sackPermitted { s.state = SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() @@ -1040,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. -func (s *sender) handleRcvdSegment(seg *segment) { +func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. - if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { s.updateRTO(time.Now().Sub(s.rttMeasureTime)) s.rttMeasureSeqNum = s.sndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && seg.parsedOptions.TS { - s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber) + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. if s.ep.sackPermitted { - for _, sb := range seg.parsedOptions.SACKBlocks { + for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: // * SACK block acks data after the ack number in the @@ -1067,22 +1303,40 @@ func (s *sender) handleRcvdSegment(seg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) - seg.hasNewSACKInfo = true + rcvdSeg.hasNewSACKInfo = true } } s.SetPipe() } // Count the duplicates and do the fast retransmit if needed. - rtx := s.checkDuplicateAck(seg) + rtx := s.checkDuplicateAck(rcvdSeg) // Stash away the current window size. - s.sndWnd = seg.window + s.sndWnd = rcvdSeg.window + + ack := rcvdSeg.ackNumber + + // Disable zero window probing if remote advertizes a non-zero receive + // window. This can be with an ACK to the zero window probe (where the + // acknumber refers to the already acknowledged byte) OR to any previously + // unacknowledged segment. + if s.zeroWindowProbing && rcvdSeg.window > 0 && + (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { + s.disableZeroWindowProbing() + } + + // On receiving the ACK for the zero window probe, account for it and + // skip trying to send any segment as we are still probing for + // receive window to become non-zero. + if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna { + s.unackZeroWindowProbes-- + return + } // Ignore ack if it doesn't acknowledge any new data. - ack := seg.ackNumber if (ack - 1).InRange(s.sndUna, s.sndNxt) { s.dupAckCount = 0 @@ -1094,15 +1348,15 @@ func (s *sender) handleRcvdSegment(seg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 { + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. - elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond + elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond s.updateRTO(elapsed) } // When an ack is received we must rearm the timer. - // RFC 6298 5.2 + // RFC 6298 5.3 s.resendTimer.enable(s.rto) // Remove all acknowledged data from the write list. @@ -1111,6 +1365,9 @@ func (s *sender) handleRcvdSegment(seg *segment) { ackLeft := acked originalOutstanding := s.outstanding + s.rtt.Lock() + srtt := s.rtt.srtt + s.rtt.Unlock() for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1129,6 +1386,12 @@ func (s *sender) handleRcvdSegment(seg *segment) { if s.writeNext == seg { s.writeNext = seg.Next() } + + // Update the RACK fields if SACK is enabled. + if s.ep.sackPermitted { + s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset) + } + s.writeList.Remove(seg) // if SACK is enabled then Only reduce outstanding if @@ -1169,6 +1432,8 @@ func (s *sender) handleRcvdSegment(seg *segment) { // RFC 6298 Rule 5.3 if s.sndUna == s.sndNxt { s.outstanding = 0 + // Reset firstRetransmittedSegXmitTime to the zero value. + s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() } } @@ -1182,14 +1447,14 @@ func (s *sender) handleRcvdSegment(seg *segment) { // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo { + if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo { s.sendData() } } // sendSegment sends the specified segment. func (s *sender) sendSegment(seg *segment) *tcpip.Error { - if !seg.xmitTime.IsZero() { + if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { @@ -1197,7 +1462,24 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { } } seg.xmitTime = time.Now() - return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) + seg.xmitCount++ + err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) + + // Every time a packet containing data is sent (including a + // retransmission), if SACK is enabled and we are retransmitting data + // then use the conservative timer described in RFC6675 Section 6.0, + // otherwise follow the standard time described in RFC6298 Section 5.1. + if err != nil && seg.data.Size() != 0 { + if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted { + s.resendTimer.enable(s.rto) + } else { + if !s.resendTimer.enabled() { + s.resendTimer.enable(s.rto) + } + } + } + + return err } // sendSegmentFromView sends a new segment containing the given payload, flags @@ -1213,19 +1495,5 @@ func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq // Remember the max sent ack. s.maxSentAck = rcvNxt - // Every time a packet containing data is sent (including a - // retransmission), if SACK is enabled then use the conservative timer - // described in RFC6675 Section 4.0, otherwise follow the standard time - // described in RFC6298 Section 5.2. - if data.Size() != 0 { - if s.ep.sackPermitted { - s.resendTimer.enable(s.rto) - } else { - if !s.resendTimer.enabled() { - s.resendTimer.enable(s.rto) - } - } - } - return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 12eff8afc..8b20c3455 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) } + +// saveFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { + return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} +} + +// loadFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { + s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 782d7b42c..b9993ce1a 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/test/testutil" ) func TestFastRecovery(t *testing.T) { @@ -40,7 +41,7 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -49,7 +50,7 @@ func TestFastRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -86,16 +87,23 @@ func TestFastRecovery(t *testing.T) { // Receive the retransmitted packet. c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } + // Wait before checking metrics. + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want) + } + return nil } - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Now send 7 mode duplicate acks. Each of these should cause a window @@ -117,12 +125,18 @@ func TestFastRecovery(t *testing.T) { // Receive the retransmit due to partial ack. c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + // Wait before checking metrics. + metricPollFn = func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) + } + return nil } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Receive the 10 extra packets that should have been released due to @@ -192,7 +206,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -201,7 +215,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } expected := tcp.InitialCwnd @@ -234,7 +248,7 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -243,7 +257,7 @@ func TestCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -338,7 +352,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { @@ -348,7 +362,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -447,7 +461,7 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -457,11 +471,11 @@ func TestRetransmit(t *testing.T) { // MTU size though. half := data[:len(data)/2] if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } half = data[len(data)/2:] if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -492,24 +506,33 @@ func TestRetransmit(t *testing.T) { rtxOffset := bytesRead - maxPayload*expected c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) - } + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) - } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) + } + + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + return nil } - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) + // Poll when checking metrics. + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Acknowledge half of the pending data. diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go new file mode 100644 index 000000000..e03f101e8 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" +) + +// TestRACKUpdate tests the RACK related fields are updated when an ACK is +// received on a SACK enabled connection. +func TestRACKUpdate(t *testing.T) { + const maxPayload = 10 + const tsOptionSize = 12 + const maxTCPOptionSize = 40 + + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + defer c.Cleanup() + + var xmitTime time.Time + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.RACKState is what we expect. + if state.Sender.RACKState.XmitTime.Before(xmitTime) { + t.Fatalf("RACK transmit time failed to update when an ACK is received") + } + + gotSeq := state.Sender.RACKState.EndSequence + wantSeq := state.Sender.SndNxt + if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } + + if state.Sender.RACKState.RTT == 0 { + t.Fatalf("RACK RTT failed to update when an ACK is received") + } + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + data := buffer.NewView(maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + xmitTime = time.Now() + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + c.SendAck(790, bytesRead) + time.Sleep(200 * time.Millisecond) +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index afea124ec..99521f0c1 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/test/testutil" ) // createConnectedWithSACKPermittedOption creates and connects c.ep with the @@ -46,7 +47,7 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) } } @@ -149,21 +150,22 @@ func TestSackPermittedAccept(t *testing.T) { {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + for _, tc := range testCases { t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } for _, sackEnabled := range []bool{false, true} { t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + + if tc.cookieEnabled { + // Set the SynRcvd threshold to + // zero to force a syn cookie + // based accept to happen. + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } setStackSACKPermitted(t, c, sackEnabled) rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) @@ -222,21 +224,23 @@ func TestSackDisabledAccept(t *testing.T) { {true, -1, 0xffff}, // When cookie is used window scaling is disabled. {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + for _, tc := range testCases { t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } for _, sackEnabled := range []bool{false, true} { t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + + if tc.cookieEnabled { + // Set the SynRcvd threshold to + // zero to force a syn cookie + // based accept to happen. + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + setStackSACKPermitted(t, c, sackEnabled) rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) @@ -387,7 +391,7 @@ func TestSACKRecovery(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -396,7 +400,7 @@ func TestSACKRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -436,21 +440,28 @@ func TestSACKRecovery(t *testing.T) { // Receive the retransmitted packet. c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil + } + + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause @@ -514,22 +525,28 @@ func TestSACKRecovery(t *testing.T) { bytesRead += maxPayload } - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } + metricPollFn = func() error { + // In SACK recovery only the first segment is fast retransmitted when + // entering recovery. + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) - } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want) + } + return nil + } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6d808328c..55ae09a2f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -34,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/pkg/waiter" ) @@ -55,7 +57,7 @@ func TestGiveUpConnect(t *testing.T) { var wq waiter.Queue ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Register for notification, then start connection attempt. @@ -64,7 +66,7 @@ func TestGiveUpConnect(t *testing.T) { defer wq.EventUnregister(&waitEntry) if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } // Close the connection, wait for completion. @@ -73,7 +75,21 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) + t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted) + } + + // Call Connect again to retreive the handshake failure status + // and stats updates. + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { + t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted) + } + + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } } @@ -86,7 +102,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) } } @@ -99,10 +115,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) + t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) } } @@ -113,20 +129,38 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { stats := c.Stack().Stats() ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } c.EP = ep want := stats.TCP.FailedConnectionAttempts.Value() + 1 if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute) } if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) + t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) + } +} + +func TestCloseWithoutConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + c.EP.Close() + + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -140,10 +174,10 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) + t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) } } @@ -154,16 +188,16 @@ func TestTCPResetsSentIncrement(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } want := stats.TCP.SegmentsSent.Value() + 1 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -194,8 +228,15 @@ func TestTCPResetsSentIncrement(t *testing.T) { c.SendPacket(nil, ackHeaders) c.GetPacket() - if got := stats.TCP.ResetsSent.Value(); got != want { - t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) + + metricPollFn := func() error { + if got := stats.TCP.ResetsSent.Value(); got != want { + return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) + } + return nil + } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } } @@ -206,17 +247,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -256,7 +298,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -264,6 +306,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -271,7 +320,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SeqNum(uint32(c.IRS+1)), checker.AckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -285,6 +333,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -296,8 +349,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) } func TestTCPResetsReceivedIncrement(t *testing.T) { @@ -320,7 +373,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { }) if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) } } @@ -344,7 +397,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { }) if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) } c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) } @@ -368,7 +421,7 @@ func TestNonBlockingClose(t *testing.T) { t0 := time.Now() ep.Close() if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) + t.Fatalf("Took too long to close: %s", diff) } } @@ -376,6 +429,13 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -396,12 +456,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -413,15 +485,219 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence number + // of the FIN itself. + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) break } } +// TestCurrentConnectedIncrement tests increment of the current +// established and connected counters. +func TestCurrentConnectedIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) + } + gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() + if gotConnected != 1 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) + } + + ep.Close() + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) + } + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Wait for a little more than the TIME-WAIT duration for the socket to + // transition to CLOSED state. + time.Sleep(1200 * time.Millisecond) + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + +// TestClosingWithEnqueuedSegments tests handling of still enqueued segments +// when the endpoint transitions to StateClose. The in-flight segments would be +// re-enqueued to a any listening endpoint. +func TestClosingWithEnqueuedSegments(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { + t.Errorf("unexpected endpoint state: want %d, got %d", want, got) + } + + // Send a FIN for ESTABLISHED --> CLOSED-WAIT + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Get the ACK for the FIN we sent. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Give the stack a few ms to transition the endpoint out of ESTABLISHED + // state. + time.Sleep(10 * time.Millisecond) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { + t.Errorf("unexpected endpoint state: want %d, got %d", want, got) + } + + // Close the application endpoint for CLOSE_WAIT --> LAST_ACK + ep.Close() + + // Get the FIN + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } + + // Pause the endpoint`s protocolMainLoop. + ep.(interface{ StopWork() }).StopWork() + + // Enqueue last ACK followed by an ACK matching the endpoint + // + // Send Last ACK for LAST_ACK --> CLOSED + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Send a packet with ACK set, this would generate RST when + // not using SYN cookies as in this test. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 792, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Unpause endpoint`s protocolMainLoop. + ep.(interface{ ResumeWork() }).ResumeWork() + + // Wait for the protocolMainLoop to resume and update state. + time.Sleep(10 * time.Millisecond) + + // Expect the endpoint to be closed. + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + + // Check if the endpoint was moved to CLOSED and netstack a reset in + // response to the ACK packet that we sent after last-ACK. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), + ), + ) +} + func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -433,7 +709,7 @@ func TestSimpleReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -456,7 +732,7 @@ func TestSimpleReceive(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -474,6 +750,488 @@ func TestSimpleReceive(t *testing.T) { ) } +// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when +// creating a new active TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnect(t *testing.T) { + const mtu = 5000 + + ips := []struct { + name string + createEP func(*context.Context) + connectAddr tcpip.Address + checker func(*testing.T, *context.Context, uint16, int) + maxMSS uint16 + }{ + { + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + connectAddr: context.TestAddr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, + }, + { + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(true) + }, + connectAddr: context.TestV6Addr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, + }, + } + + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + ip.createEP(c) + + // Set the MSS socket option. + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} + if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } + + // Receive SYN packet with our user supplied MSS. + ip.checker(t, c, test.expMSS, ws) + }) + } + }) + } +} + +// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used +// when completing the handshake for a new TCP connection from a TCP +// listening socket. It should be present in the sent TCP SYN-ACK segment. +func TestUserSuppliedMSSOnListenAccept(t *testing.T) { + const ( + nonSynCookieAccepts = 2 + totalAccepts = 4 + mtu = 5000 + ) + + ips := []struct { + name string + createEP func(*context.Context) + sendPkt func(*context.Context, *context.Headers) + checker func(*testing.T, *context.Context, uint16, uint16) + maxMSS uint16 + }{ + { + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendPacket(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, + }, + { + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(false) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendV6Packet(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, + }, + } + + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + ip.createEP(c) + + // Set the SynRcvd threshold to force a syn cookie based accept to happen. + opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err) + } + + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } + + bindAddr := tcpip.FullAddress{Port: context.StackPort} + if err := c.EP.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s:", bindAddr, err) + } + + if err := c.EP.Listen(totalAccepts); err != nil { + t.Fatalf("Listen(%d): %s:", totalAccepts, err) + } + + // The first nonSynCookieAccepts packets sent will trigger a gorooutine + // based accept. The rest will trigger a cookie based accept. + for i := 0; i < totalAccepts; i++ { + // Send a SYN requests. + iss := seqnum.Value(i) + srcPort := context.TestPort + uint16(i) + ip.sendPkt(c, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + ip.checker(t, c, srcPort, test.expMSS) + } + }) + } + }) + } +} +func TestSendRstOnListenerRxSynAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxSynAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv4. +func TestTCPAckBeforeAcceptV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv6. +func TestTCPAckBeforeAcceptV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestSendRstOnListenerRxAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true /* v6Only */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +// TestListenShutdown tests for the listening endpoint replying with RST +// on read shutdown. +func TestListenShutdown(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(1 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) + + // Expect the listening endpoint to reset the connection. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) +} + +// TestListenCloseWhileConnect tests for the listening endpoint to +// drain the accept-queue when closed. This should reset all of the +// pending connections that are waiting to be accepted. +func TestListenCloseWhileConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(1 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventIn) + defer c.WQ.EventUnregister(&waitEntry) + + executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + // Wait for the new endpoint created because of handshake to be delivered + // to the listening endpoint's accept queue. + <-notifyCh + + // Close the listening endpoint. + c.EP.Close() + + // Expect the listening endpoint to reset the connection. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -485,17 +1243,17 @@ func TestTOSV4(t *testing.T) { c.EP = ep const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) } - var v tcpip.IPv4TOSOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Errorf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) } testV4Connect(t, c, checker.TOS(tos, 0)) @@ -533,17 +1291,17 @@ func TestTrafficClassV6(t *testing.T) { c.CreateV6Endpoint(false) const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { + t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) } - var v tcpip.IPv6TrafficClassOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) } // Test the connection request. @@ -578,12 +1336,12 @@ func TestTrafficClassV6(t *testing.T) { func TestConnectBindToDevice(t *testing.T) { for _, test := range []struct { name string - device string + device tcpip.NICID want tcp.EndpointState }{ - {"RightDevice", "nic1", tcp.StateEstablished}, - {"WrongDevice", "nic2", tcp.StateSynSent}, - {"AnyDevice", "", tcp.StateEstablished}, + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, } { t.Run(test.name, func(t *testing.T) { c := context.New(t, defaultMTU) @@ -598,7 +1356,7 @@ func TestConnectBindToDevice(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("unexpected return value from Connect: %s", err) } // Receive SYN packet. @@ -610,7 +1368,7 @@ func TestConnectBindToDevice(t *testing.T) { ), ) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) } tcpHdr := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) @@ -629,7 +1387,95 @@ func TestConnectBindToDevice(t *testing.T) { c.GetPacket() if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) + } + }) + } +} + +func TestSynSent(t *testing.T) { + for _, test := range []struct { + name string + reset bool + }{ + {"RstOnSynSent", true}, + {"CloseOnSynSent", false}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) + + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { + t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + if test.reset { + // Send a packet with a proper ACK and a RST flag to cause the socket + // to error and close out. + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + } else { + c.EP.Close() + } + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } + + if test.reset { + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused) + } + } else { + if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted) + } + } + + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } + + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) } }) } @@ -646,7 +1492,7 @@ func TestOutOfOrderReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send second half of data first, with seqnum 3 ahead of expected. @@ -673,7 +1519,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send the first 3 bytes now. @@ -700,7 +1546,7 @@ func TestOutOfOrderReceive(t *testing.T) { } continue } - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } read = append(read, v...) @@ -730,7 +1576,7 @@ func TestOutOfOrderFlood(t *testing.T) { c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send 100 packets before the actual one that is expected. @@ -807,7 +1653,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -850,7 +1696,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // This final ACK should be ignored because an ACK on a reset doesn't mean @@ -876,7 +1722,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -918,7 +1764,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { )) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Cause a RST to be generated by closing the read end now since we have @@ -930,12 +1776,14 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence + // number of the FIN itself. + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // The ACK to the FIN should now be rejected since the connection has been @@ -957,19 +1805,19 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) } } @@ -985,7 +1833,7 @@ func TestFullWindowReceive(t *testing.T) { _, _, err := c.EP.Read(nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } // Fill up the window. @@ -1020,7 +1868,7 @@ func TestFullWindowReceive(t *testing.T) { // Receive data and check it. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -1029,7 +1877,7 @@ func TestFullWindowReceive(t *testing.T) { var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) } // Check that we get an ACK for the newly non-zero window. @@ -1052,7 +1900,7 @@ func TestNoWindowShrinking(t *testing.T) { c.CreateConnected(789, 30000, 10) if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) } we, ch := waiter.NewChannelEntry(nil) @@ -1060,7 +1908,7 @@ func TestNoWindowShrinking(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send 3 bytes, check that the peer acknowledges them. @@ -1124,7 +1972,7 @@ func TestNoWindowShrinking(t *testing.T) { for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } read = append(read, v...) @@ -1158,7 +2006,7 @@ func TestSimpleSend(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received. @@ -1192,7 +2040,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, -1 /* epRcvBuf */) + c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1200,11 +2048,20 @@ func TestZeroWindowSend(t *testing.T) { _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) if err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } - // Since the window is currently zero, check that no packet is received. - c.CheckNoPacket("Packet received when window is zero") + // Check if we got a zero-window probe. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) // Open up the window. Data should be received now. c.SendPacket(nil, &context.Headers{ @@ -1217,7 +2074,7 @@ func TestZeroWindowSend(t *testing.T) { }) // Check that data is received. - b := c.GetPacket() + b = c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( @@ -1259,7 +2116,7 @@ func TestScaledWindowConnect(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xbfff, @@ -1291,7 +2148,7 @@ func TestNonScaledWindowConnect(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xffff, @@ -1319,21 +2176,21 @@ func TestScaledWindowAccept(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake. @@ -1351,7 +2208,7 @@ func TestScaledWindowAccept(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -1364,7 +2221,7 @@ func TestScaledWindowAccept(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xbfff, @@ -1392,21 +2249,21 @@ func TestNonScaledWindowAccept(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN @@ -1425,7 +2282,7 @@ func TestNonScaledWindowAccept(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -1438,7 +2295,7 @@ func TestNonScaledWindowAccept(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xffff, @@ -1522,10 +2379,14 @@ func TestZeroScaledWindowReceive(t *testing.T) { ) } - // Read some data. An ack should be sent in response to that. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) + // Read at least 1MSS of data. An ack should be sent in response to that. + sz := 0 + for sz < defaultMTU { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + sz += len(v) } checker.IPv4(t, c.GetPacket(), @@ -1534,7 +2395,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+1), checker.AckNum(uint32(790+sent)), - checker.Window(uint16(len(v)>>ws)), + checker.Window(uint16(sz>>ws)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1558,10 +2419,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) + ep.SetSockOptBool(tcpip.CorkOption, true) }, func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) + ep.SetSockOptBool(tcpip.CorkOption, false) }, }, } @@ -1573,20 +2434,50 @@ func TestSegmentMerging(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - // Prevent the endpoint from processing packets. - test.stop(c.EP) + // Send tcp.InitialCwnd number of segments to fill up + // InitialWindow but don't ACK. That should prevent + // anymore packets from going out. + for i := 0; i < tcp.InitialCwnd; i++ { + view := buffer.NewViewFromBytes([]byte{0}) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } + } + // Now send the segments that should get merged as the congestion + // window is full and we won't be able to send any more packets. var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } - // Let the endpoint process the segments that we just sent. - test.resume(c.EP) + // Check that we get tcp.InitialCwnd packets. + for i := 0; i < tcp.InitialCwnd; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize+1), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+uint32(i)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. + RcvWnd: 30000, + }) // Check that data is received. b := c.GetPacket() @@ -1594,7 +2485,7 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+11), checker.AckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), @@ -1610,7 +2501,7 @@ func TestSegmentMerging(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), + AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), RcvWnd: 30000, }) }) @@ -1623,14 +2514,14 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -1671,13 +2562,13 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -1704,7 +2595,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptInt(tcpip.DelayOption, 0) + c.EP.SetSockOptBool(tcpip.DelayOption, false) // Check that data is received. second := c.GetPacket() @@ -1741,8 +2632,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, } for _, test := range tests { @@ -1761,7 +2652,7 @@ func TestMSSNotDelayed(t *testing.T) { for i, data := range allData { view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -1812,7 +2703,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received in chunks. @@ -1880,15 +2771,15 @@ func TestSetTTL(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } - if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("unexpected return value from Connect: %s", err) } // Receive SYN packet. @@ -1920,7 +2811,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() @@ -1928,15 +2819,15 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake. @@ -1954,7 +2845,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -1974,26 +2865,24 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } // Create EP and start listening. wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake. @@ -2011,7 +2900,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -2045,7 +2934,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) { select { case err := <-ch: if err != nil { - t.Fatalf("Error creating endpoint: %v", err) + t.Fatalf("Error creating endpoint: %s", err) } case <-time.After(2 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -2064,7 +2953,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Set the buffer size to a deterministic size so that we can check the @@ -2072,7 +2961,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { const rcvBufferSize = 0x20000 const wndScale = 2 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } // Start connection attempt. @@ -2081,7 +2970,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { defer c.WQ.EventUnregister(&we) if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } // Receive SYN packet. @@ -2135,7 +3024,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { select { case <-ch: if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("GetSockOpt failed: %s", err) } case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -2150,22 +3039,22 @@ func TestCloseListener(t *testing.T) { var wq waiter.Queue ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Close the listener and measure how long it takes. t0 := time.Now() ep.Close() if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) + t.Fatalf("Took too long to close: %s", diff) } } @@ -2201,16 +3090,26 @@ loop: case tcpip.ErrConnectionReset: break loop default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } } // Expect the state to be StateError and subsequent Reads to fail with HardError. if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { @@ -2234,7 +3133,162 @@ func TestSendOnResetConnection(t *testing.T) { // Try to write. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) + } +} + +// TestMaxRetransmitsTimeout tests if the connection is timed out after +// a segment has been retransmitted MaxRetries times. +func TestMaxRetransmitsTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const numRetries = 2 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { + t.Fatalf("could not set protocol option MaxRetries.\n") + } + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Expect first transmit and MaxRetries retransmits. + for i := 0; i < numRetries+1; i++ { + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), + ), + ) + } + // Wait for the connection to timeout after MaxRetries retransmits. + initRTO := 1 * time.Second + select { + case <-notifyCh: + case <-time.After((2 << numRetries) * initRTO): + t.Fatalf("connection still alive after maximum retransmits.\n") + } + + // Send an ACK and expect a RST as the connection would have been closed. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + +// TestMaxRTO tests if the retransmit interval caps to MaxRTO. +func TestMaxRTO(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + rto := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err) + } + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + const numRetransmits = 2 + for i := 0; i < numRetransmits; i++ { + start := time.Now() + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { + t.Errorf("Retransmit interval not capped to MaxRTO.\n") + } + } +} + +// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is +// unique on retransmits. +func TestRetransmitIPv4IDUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. + if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. + for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + id := header.IPv4(pkt).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} + } + }) } } @@ -2246,7 +3300,7 @@ func TestFinImmediately(t *testing.T) { // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2289,7 +3343,7 @@ func TestFinRetransmit(t *testing.T) { // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2344,7 +3398,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Write something out, and have it acknowledged. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -2370,7 +3424,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Shutdown, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2417,7 +3471,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { view := buffer.NewView(10) for i := tcp.InitialCwnd; i > 0; i-- { if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } } @@ -2439,7 +3493,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // because the congestion window doesn't allow it. Wait until a // retransmit is received. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2503,7 +3557,7 @@ func TestFinWithPendingData(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -2529,7 +3583,7 @@ func TestFinWithPendingData(t *testing.T) { // Write new data, but don't acknowledge it. if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2545,7 +3599,7 @@ func TestFinWithPendingData(t *testing.T) { // Shutdown the connection, check that we do get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2590,7 +3644,7 @@ func TestFinWithPartialAck(t *testing.T) { // FIN from the test side. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -2627,7 +3681,7 @@ func TestFinWithPartialAck(t *testing.T) { // Write new data, but don't acknowledge it. if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2643,7 +3697,7 @@ func TestFinWithPartialAck(t *testing.T) { // Shutdown the connection, check that we do get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2689,20 +3743,20 @@ func TestUpdateListenBacklog(t *testing.T) { var wq waiter.Queue ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Update the backlog with another Listen() on the same endpoint. if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %v", err) + t.Fatalf("Listen failed to update backlog: %s", err) } ep.Close() @@ -2734,7 +3788,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { // Send some data. Check that it's capped by the window size. view := buffer.NewView(65535) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that only data that fits in the scaled window is sent. @@ -2780,18 +3834,18 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { }) if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) + t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) } // Ensure there were no errors during handshake. If these stats have // incremented, then the connection should not have been established. if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) } if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0) } } @@ -2809,16 +3863,16 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) } } @@ -2836,7 +3890,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 @@ -2905,6 +3959,13 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) @@ -2912,12 +3973,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2931,7 +3992,7 @@ func TestReadAfterClosedState(t *testing.T) { ) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Send some data and acknowledge the FIN. @@ -2955,13 +4016,12 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Wait for receive to be notified. @@ -2975,7 +4035,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -2986,7 +4046,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -2996,11 +4056,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -3014,66 +4074,84 @@ func TestReusePort(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } c.EP.Close() c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } c.EP.Close() // Second case, an endpoint that was bound and is connecting.. c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } c.EP.Close() c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } c.EP.Close() // Third case, an endpoint that was bound and is listening. c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } c.EP.Close() c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } } @@ -3082,11 +4160,11 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("GetSockOpt failed: %s", err) } if int(s) != v { - t.Fatalf("got receive buffer size = %v, want = %v", s, v) + t.Fatalf("got receive buffer size = %d, want = %d", s, v) } } @@ -3095,11 +4173,11 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("GetSockOpt failed: %s", err) } if int(s) != v { - t.Fatalf("got send buffer size = %v, want = %v", s, v) + t.Fatalf("got send buffer size = %d, want = %d", s, v) } } @@ -3112,7 +4190,7 @@ func TestDefaultBufferSizes(t *testing.T) { // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer func() { if ep != nil { @@ -3124,28 +4202,34 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } ep.Close() ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %v", err) } ep.Close() ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) @@ -3161,41 +4245,41 @@ func TestMinMaxBufferSizes(t *testing.T) { // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Set values below the min. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) @@ -3208,50 +4292,45 @@ func TestBindToDeviceOption(t *testing.T) { ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) + if err := s.CreateNIC(321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %s", err) } - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %s, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } @@ -3314,12 +4393,12 @@ func TestSelfConnect(t *testing.T) { var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Register for notification, then start connection attempt. @@ -3328,12 +4407,12 @@ func TestSelfConnect(t *testing.T) { defer wq.EventUnregister(&waitEntry) if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } <-notifyCh if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("Connect failed: %v", err) + t.Fatalf("Connect failed: %s", err) } // Write something. @@ -3341,7 +4420,7 @@ func TestSelfConnect(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Read back what was written. @@ -3350,12 +4429,12 @@ func TestSelfConnect(t *testing.T) { rd, _, err := ep.Read(nil) if err != nil { if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } <-notifyCh rd, _, err = ep.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } } @@ -3439,18 +4518,18 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { } ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } eps = append(eps, ep) switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) } case "dual": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { + t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) } default: t.Fatalf("unknown network: '%s'", network) @@ -3490,7 +4569,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %v", i, err) + t.Fatalf("Bind(%d) failed: %s", i, err) } } want := tcpip.ErrConnectStarted @@ -3498,7 +4577,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { want = tcpip.ErrNoPortAvailable } if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want) + t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) } }) } @@ -3532,7 +4611,7 @@ func TestPathMTUDiscovery(t *testing.T) { } if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { @@ -3635,7 +4714,7 @@ func TestStackSetCongestionControl(t *testing.T) { var oldCC tcpip.CongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) + t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { @@ -3722,12 +4801,12 @@ func TestEndpointSetCongestionControl(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } var oldCC tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) + t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err) } if connected { @@ -3735,12 +4814,12 @@ func TestEndpointSetCongestionControl(t *testing.T) { } if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) + t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) + t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err) } got, want := cc, oldCC @@ -3763,7 +4842,7 @@ func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() opt := tcpip.CongestionControlOption("cubic") if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err) } } @@ -3773,10 +4852,11 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + const keepAliveInterval = 3 * time.Second + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -3804,14 +4884,14 @@ func TestKeepalive(t *testing.T) { // Check that the connection is still alive. if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. view := buffer.NewView(3) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -3864,18 +4944,45 @@ func TestKeepalive(t *testing.T) { ) } + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + keepAliveInterval/2) + // The connection should be terminated after 5 unacked keepalives. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), ) + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) + } + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -3890,7 +4997,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki RcvWnd: 30000, }) - // Receive the SYN-ACK reply.w + // Receive the SYN-ACK reply. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) iss = seqnum.Value(tcp.SequenceNumber()) @@ -3923,6 +5030,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki return irs, iss } +func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + // Send a SYN request. + irs = seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetV6Packet() + tcp := header.TCP(header.IPv6(b).Payload()) + iss = seqnum.Value(tcp.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + + if synCookieInUse { + // When cookies are in use window scaling is disabled. + tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ + WS: -1, + MSS: c.MSSWithoutOptionsV6(), + })) + } + + checker.IPv6(t, b, checker.TCP(tcpCheckers...)) + + // Send ACK. + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + return irs, iss +} + // TestListenBacklogFull tests that netstack does not complete handshakes if the // listen backlog for the endpoint is full. func TestListenBacklogFull(t *testing.T) { @@ -3933,19 +5084,19 @@ func TestListenBacklogFull(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Bind to wildcard. if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Test acceptance. // Start listening. listenBacklog := 2 if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } for i := 0; i < listenBacklog; i++ { @@ -3978,7 +5129,7 @@ func TestListenBacklogFull(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4007,7 +5158,7 @@ func TestListenBacklogFull(t *testing.T) { case <-ch: newEP, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4021,7 +5172,215 @@ func TestListenBacklogFull(t *testing.T) { b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) + t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) + } +} + +// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a +// non unicast IPv4 address are not accepted. +func TestListenNoAcceptNonUnicastV4(t *testing.T) { + multicastAddr := tcpip.Address("\xe0\x00\x01\x02") + otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv4Any, + context.StackAddr, + }, + { + "SourceBroadcast", + header.IPv4Broadcast, + context.StackAddr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackAddr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackAddr, + }, + { + "DestUnspecified", + context.TestAddr, + header.IPv4Any, + }, + { + "DestBroadcast", + context.TestAddr, + header.IPv4Broadcast, + }, + { + "DestOurMulticast", + context.TestAddr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestAddr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + test := test // capture range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestAddr, context.StackAddr) + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + +// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a +// non unicast IPv6 address are not accepted. +func TestListenNoAcceptNonUnicastV6(t *testing.T) { + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv6Any, + context.StackV6Addr, + }, + { + "SourceAllNodes", + header.IPv6AllNodesMulticastAddress, + context.StackV6Addr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackV6Addr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackV6Addr, + }, + { + "DestUnspecified", + context.TestV6Addr, + header.IPv6Any, + }, + { + "DestAllNodes", + context.TestV6Addr, + header.IPv6AllNodesMulticastAddress, + }, + { + "DestOurMulticast", + context.TestV6Addr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestV6Addr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + test := test // capture range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestV6Addr, context.StackV6Addr) + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) } } @@ -4033,19 +5392,19 @@ func TestListenSynRcvdQueueFull(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Bind to wildcard. if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Test acceptance. // Start listening. listenBacklog := 1 if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send two SYN's the first one should get a SYN-ACK, the @@ -4056,7 +5415,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { SrcPort: context.TestPort, DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: irs, RcvWnd: 30000, }) @@ -4111,7 +5470,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { case <-ch: newEP, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4125,30 +5484,28 @@ func TestListenSynRcvdQueueFull(t *testing.T) { pkt := c.GetPacket() tcp = header.TCP(header.IPv4(pkt).Payload()) if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) + t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } } func TestListenBacklogFullSynCookieInUse(t *testing.T) { - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 1 - c := context.New(t, defaultMTU) defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + } + // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Bind to wildcard. if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -4156,7 +5513,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { listenBacklog := 1 portOffset := uint16(0) if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } executeHandshake(t, c, context.TestPort+portOffset, false) @@ -4167,7 +5524,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4188,7 +5546,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4207,26 +5565,145 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } +func TestSynRcvdBadSeqNumber(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcpHdr.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now send a packet with an out-of-window sequence number + largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: largeSeqnum, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Should receive an ACK with the expected SEQ number + b = c.GetPacket() + tcpCheckers = []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.AckNum(uint32(irs) + 1), + checker.SeqNum(uint32(iss + 1)), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now that the socket replied appropriately with the ACK, + // complete the connection to test that the large SEQ num + // did not change the state from SYN-RCVD. + + // Send ACK to move to ESTABLISHED state. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + newEP, _, err := c.EP.Accept() + + if err != nil && err != tcpip.ErrWouldBlock { + t.Fatalf("Accept failed: %s", err) + } + + if err == tcpip.ErrWouldBlock { + // Try to accept the connections in the backlog. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that the TCP socket is usable and in a connected state. + data := "Don't panic" + _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + + if err != nil { + t.Fatalf("Write failed: %s", err) + } + + pkt := c.GetPacket() + tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + if string(tcpHdr.Payload()) != data { + t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) + } +} + func TestPassiveConnectionAttemptIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } c.EP = ep if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } stats := c.Stack().Stats() @@ -4247,7 +5724,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4256,7 +5733,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { } if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) } } @@ -4267,14 +5744,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { stats := c.Stack().Stats() ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } c.EP = ep if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } srcPort := uint16(context.TestPort) @@ -4299,10 +5776,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { time.Sleep(50 * time.Millisecond) if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) + t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) } we, ch := waiter.NewChannelEntry(nil) @@ -4317,7 +5794,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4332,29 +5809,28 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - // Expect InvalidEndpointState errors on a read at this point. - if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { + t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected) } - if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { - t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { + t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) @@ -4371,7 +5847,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { case <-ch: aep, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4379,22 +5855,25 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } } if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } + if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { + t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected) } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } ep.Close() // Give worker goroutines time to receive the close notification. time.Sleep(1 * time.Second) if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Accepted endpoint remains open when the listen endpoint is closed. if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } } @@ -4414,13 +5893,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Enable auto-tuning. if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Change the expected window scale to match the value needed for the // maximum buffer size defined above. @@ -4464,6 +5943,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { rawEP.SendPacketWithTS(b[start:start+mss], tsVal) packetsSent++ } + // Resume the worker so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -4512,7 +5992,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { return } if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd) + t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) } }, )) @@ -4531,16 +6011,16 @@ func TestReceiveBufferAutoTuning(t *testing.T) { stk := c.Stack() // Set lower limits for auto-tuning tests. This is required because the // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. + // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Enable auto-tuning. if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Change the expected window scale to match the value needed for the // maximum buffer size used by stack. @@ -4586,6 +6066,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalSent += mss packetsSent++ } + // Resume it so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -4618,7 +6099,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Invoke the moderation API. This is required for auto-tuning // to happen. This method is normally expected to be invoked // from a higher layer than tcpip.Endpoint. So we simulate - // copying to user-space by invoking it explicitly here. + // copying to userspace by invoking it explicitly here. c.EP.ModerateRecvBuf(totalCopied) // Now send a keep-alive packet to trigger an ACK so that we can @@ -4668,3 +6149,1300 @@ func TestReceiveBufferAutoTuning(t *testing.T) { payloadSize *= 2 } } + +func TestDelayEnabled(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, false) // Delay is disabled by default. + + for _, v := range []struct { + delayEnabled tcp.DelayEnabled + wantDelayOption bool + }{ + {delayEnabled: false, wantDelayOption: false}, + {delayEnabled: true, wantDelayOption: true}, + } { + c := context.New(t, defaultMTU) + defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) + } + checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + } +} + +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { + t.Helper() + + var gotDelayEnabled tcp.DelayEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) + } + if gotDelayEnabled != wantDelayEnabled { + t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) + } + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) + if err != nil { + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) + } + gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) + if err != nil { + t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) + } + if gotDelayOption != wantDelayOption { + t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) + } +} + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } +} + +func TestTCPCloseWithData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: 30000, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now trigger a passive close by sending a FIN. + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + RcvWnd: 30000, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now write a few bytes and then close the endpoint. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b = c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } + + c.EP.Close() + // Check the FIN. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.AckNum(uint32(iss+2)), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + // First send a partial ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now send a full ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now ACK the FIN. + ackHeaders.AckNum++ + c.SendPacket(nil, ackHeaders) + + // Now send an ACK and we should get a RST back as the endpoint should + // be in CLOSED state. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Check the RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + // Ensure that on the next retransmit timer fire, the user timeout has + // expired. + initRTO := 1 * time.Second + userTimeout := initRTO / 2 + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for the retransmit timer to be fired and the user timeout to cause + // close of the connection. + select { + case <-notifyCh: + case <-time.After(2 * initRTO): + t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) + } + + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") + + next += uint32(len(view)) + + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) + } + + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + +func TestKeepaliveWithUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + const keepAliveInterval = 3 * time.Second + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + + // Set userTimeout to be the duration to be 1 keepalive + // probes. Which means that after the first probe is sent + // the second one should cause the connection to be + // closed due to userTimeout being hit. + userTimeout := 1 * keepAliveInterval + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + + // Now receive 1 keepalives, but don't ACK it. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + keepAliveInterval/2) + + // The connection should be closed with a timeout. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + +func TestIncreaseWindowOnReceive(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after recv() when the window grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, defaultMTU/2) + lastWnd := uint16(0) + + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + + lastWnd = uint16(remain) + if remain > 0xffff { + lastWnd = 0xffff + } + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(lastWnd), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + if lastWnd == 0xffff || lastWnd == 0 { + t.Fatalf("expected small, non-zero window: %d", lastWnd) + } + + // We now have < 1 MSS in the buffer space. Read the data! An + // ack should be sent in response to that. The window was not + // zero, but it grew to larger than MSS. + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("Read failed: %s", err) + } + + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("Read failed: %s", err) + } + + // After reading two packets, we surely crossed MSS. See the ack: + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(0xffff)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestIncreaseWindowOnBufferResize(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after available recv buffer grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, defaultMTU/2) + lastWnd := uint16(0) + + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + + lastWnd = uint16(remain) + if remain > 0xffff { + lastWnd = 0xffff + } + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(lastWnd), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + if lastWnd == 0xffff || lastWnd == 0 { + t.Fatalf("expected small, non-zero window: %d", lastWnd) + } + + // Increasing the buffer from should generate an ACK, + // since window grew from small value to larger equal MSS + c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) + + // After reading two packets, we surely crossed MSS. See the ack: + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(0xffff)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestTCPDeferAccept(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + } + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give a bit of time for the socket to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestTCPDeferAcceptTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + } + + // Sleep for a little of the tcpDeferAccept timeout. + time.Sleep(tcpDeferAccept + 100*time.Millisecond) + + // On timeout expiry we should get a SYN-ACK retransmission. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give sometime for the endpoint to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestResetDuringClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + iss := seqnum.Value(789) + c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) + // Send some data to make sure there is some unread + // data to trigger a reset on c.Close. + irs := c.IRS + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(1), + AckNum: irs.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(irs.Add(1))), + checker.AckNum(uint32(iss.Add(5))))) + + // Close in a separate goroutine so that we can trigger + // a race with the RST we send below. This should not + // panic due to the route being released depeding on + // whether Close() sends an active RST or the RST sent + // below is processed by the worker first. + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(5), + AckNum: c.IRS.Add(5), + RcvWnd: 30000, + Flags: header.TCPFlagRst, + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + c.EP.Close() + }() + + wg.Wait() +} + +func TestStackTimeWaitReuse(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + s := c.Stack() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) + } + if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } +} + +func TestSetStackTimeWaitReuse(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + s := c.Stack() + testCases := []struct { + v int + err *tcpip.Error + }{ + {int(tcpip.TCPTimeWaitReuseDisabled), nil}, + {int(tcpip.TCPTimeWaitReuseGlobal), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue}, + } + + for _, tc := range testCases { + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v)) + if got, want := err, tc.err; got != want { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err) + } + if tc.err != nil { + continue + } + + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) + } + + if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index a641e953d..8edbff964 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -127,16 +127,14 @@ func TestTimeStampDisabledConnect(t *testing.T) { } func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + c := context.New(t, defaultMTU) + defer c.Cleanup() if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } } - c := context.New(t, defaultMTU) - defer c.Cleanup() t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) tsVal := rand.Uint32() @@ -148,7 +146,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Unexpected error from Write: %s", err) } // Check that data is received and that the timestamp option TSEcr field @@ -190,17 +188,15 @@ func TestTimeStampEnabledAccept(t *testing.T) { } func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - c := context.New(t, defaultMTU) defer c.Cleanup() + if cookieEnabled { + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) @@ -211,7 +207,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Unexpected error from Write: %s", err) } // Check that data is received and that the timestamp option is disabled diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index 19b0d31c5..ce6a2c31d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,9 +6,8 @@ go_library( name = "context", testonly = 1, srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ - "//:sandbox", + "//visibility:public", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ef823e4ae..b6031354e 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -18,6 +18,7 @@ package context import ( "bytes" + "context" "testing" "time" @@ -142,13 +143,22 @@ func New(t *testing.T, mtu uint32) *Context { TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, }) + const sendBufferSize = 1 << 20 // 1 MiB + const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) + } + + // Increase minimum RTO in tests to avoid test flakes due to early + // retransmit in case the test executors are overloaded and cause timers + // to fire earlier than expected. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { + t.Fatalf("failed to set stack-wide minRTO: %s", err) } // Some of the congestion control tests send up to 640 packets, we so @@ -158,15 +168,17 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) if testing.Verbose() { wep2 = sniffer.New(channel.New(1000, mtu, "")) } - if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts2 := stack.NICOptions{Name: "nic2"} + if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { @@ -192,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context { t: t, s: s, linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), + WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), } } @@ -201,6 +213,7 @@ func (c *Context) Cleanup() { if c.EP != nil { c.EP.Close() } + c.Stack().Close() } // Stack returns a reference to the stack in the Context. @@ -213,11 +226,10 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - select { - case <-c.linkEP.C: + ctx, cancel := context.WithTimeout(context.Background(), wait) + defer cancel() + if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) - - case <-time.After(wait): } } @@ -231,27 +243,29 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + c.t.Helper() - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + c.t.Fatalf("Packet wasn't written out") + return nil + } - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) } - return nil + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // GetPacketNonBlocking reads a packet from the link layer endpoint @@ -259,24 +273,26 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + c.t.Helper() - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - default: + p, ok := c.linkEP.Read() + if !ok { return nil } + + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { +func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { @@ -302,11 +318,20 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // BuildSegment builds a TCP segment based on the given Headers and payload. func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { + return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) +} + +// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, +// payload and source and destination IPv4 addresses. +func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -319,8 +344,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, + SrcAddr: src, + DstAddr: dst, }) ip.SetChecksum(^ip.CalculateChecksum()) @@ -337,7 +362,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -350,13 +375,29 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: s, + }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: c.BuildSegment(payload, h), + }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) +} + +// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload +// & TCPheaders) in an IPv4 packet via the link layer endpoint using the +// provided source and destination IPv4 addresses. +func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: c.BuildSegmentWithAddrs(payload, h, src, dst), + }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendAck sends an ACK packet. @@ -389,6 +430,8 @@ func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlock // verifies that the packet packet payload of packet matches the slice // of data indicated by offset & size. func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { + c.t.Helper() + c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) } @@ -397,6 +440,8 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { // data indicated by offset & size and skips optlen bytes in addition to the IP // TCP headers when comparing the data. func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { + c.t.Helper() + b := c.GetPacket() checker.IPv4(c.t, b, checker.PayloadLen(size+header.TCPMinimumSize+optlen), @@ -419,6 +464,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op // data indicated by offset & size. It returns true if a packet was received and // processed. func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { + c.t.Helper() + b := c.GetPacketNonBlocking() if b == nil { return false @@ -450,11 +497,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.EP.SetSockOpt(v); err != nil { + if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } @@ -462,28 +505,36 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // GetV6Packet reads a single packet from the link layer endpoint of the context // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b + c.t.Helper() - case <-time.After(2 * time.Second): + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil + } + + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() - return nil + checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) + return b } // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of // the context. func (c *Context) SendV6Packet(payload []byte, h *Headers) { + c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) +} + +// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer +// endpoint of the context using the provided source and destination IPv6 +// addresses. +func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -494,8 +545,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { PayloadLength: uint16(header.TCPMinimumSize + len(payload)), NextHeader: uint8(tcp.ProtocolNumber), HopLimit: 65, - SrcAddr: TestV6Addr, - DstAddr: StackV6Addr, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. @@ -511,14 +562,17 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + }) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) } // CreateConnected creates a connected TCP endpoint. @@ -535,6 +589,8 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf // // PreCondition: c.EP must already be created. func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { + c.t.Helper() + // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) @@ -1051,7 +1107,11 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { - c.linkEP.GSO = enable + if enable { + c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + } else { + c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + } } // MSSWithoutOptions returns the value for the MSS used by the stack when no @@ -1059,3 +1119,9 @@ func (c *Context) SetGSOEnabled(enable bool) { func (c *Context) MSSWithoutOptions() uint16 { return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) } + +// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no +// options are in use for IPv6 packets. +func (c *Context) MSSWithoutOptionsV6() uint16 { + return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) +} diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index c70525f27..7981d469b 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { t.timer.Stop() + *t = timer{} } // checkExpiration checks if the given timer has actually expired, it should be diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go new file mode 100644 index 000000000..dbd6dff54 --- /dev/null +++ b/pkg/tcpip/transport/tcp/timer_test.go @@ -0,0 +1,47 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sleep" +) + +func TestCleanup(t *testing.T) { + const ( + timerDurationSeconds = 2 + isAssertedTimeoutSeconds = timerDurationSeconds + 1 + ) + + tmr := timer{} + w := sleep.Waker{} + tmr.init(&w) + tmr.enable(timerDurationSeconds * time.Second) + tmr.cleanup() + + if want := (timer{}); tmr != want { + t.Errorf("got tmr = %+v, want = %+v", tmr, want) + } + + // The waker should not be asserted. + for i := 0; i < isAssertedTimeoutSeconds; i++ { + time.Sleep(time.Second) + if w.IsAsserted() { + t.Fatalf("waker asserted unexpectedly") + } + } +} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 43fcc27f0..3ad6994a7 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tcpconntrack", srcs = ["tcp_conntrack.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 93712cd45..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { @@ -311,17 +316,7 @@ type stream struct { // the window is zero, if it's a packet with no payload and sequence number // equal to una. func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - wnd := s.una.Size(s.end) - if wnd == 0 { - return segLen == 0 && segSeq == s.una - } - - // Make sure [segSeq, seqSeq+segLen) is non-empty. - if segLen == 0 { - segLen = 1 - } - - return seqnum.Overlap(s.una, wnd, segSeq, segLen) + return header.Acceptable(segSeq, segLen, s.una, s.end) } // closed determines if the stream has already been closed. This happens when @@ -347,3 +342,16 @@ func logicalLen(tcp header.TCP) seqnum.Size { } return l } + +// IsEmpty returns true if tcb is not initialized. +func (t *TCB) IsEmpty() bool { + if t.inbound != (stream{}) || t.outbound != (stream{}) { + return false + } + + if t.firstFin != nil || t.state != ResultDrop { + return false + } + + return true +} diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c9460aa0d..b5d2d0ba6 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -25,15 +24,15 @@ go_library( "protocol.go", "udp_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/waiter", @@ -59,11 +58,3 @@ go_test( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "udp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 91c8487f3..73608783c 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,12 +15,14 @@ package udp import ( - "sync" + "fmt" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -29,11 +31,11 @@ import ( type udpPacket struct { udpPacketEntry senderAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` + // tos stores either the receiveTOS or receiveTClass value. + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -80,6 +82,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -93,6 +96,7 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int + sndBufSizeMax int state EndpointState route stack.Route `state:"manual"` dstPort uint16 @@ -102,14 +106,34 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID multicastLoop bool - reusePort bool + portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool + noChecksum bool + + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` + + // Values used to reserve a port or register a transport endpoint. + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags // sendTOS represents IPv4 TOS or IPv6 TrafficClass, // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 + // receiveTOS determines if the incoming IPv4 TOS header field is passed + // as ancillary data to ControlMessages on Read. + receiveTOS bool + + // receiveTClass determines if the incoming IPv6 TClass header field is + // passed as ancillary data to ControlMessages on Read. + receiveTClass bool + + // receiveIPPacketInfo determines if the packet info is returned by Read. + receiveIPPacketInfo bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -127,6 +151,9 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // +stateify savable @@ -136,7 +163,7 @@ type multicastMembership struct { } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - return &endpoint{ + e := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, @@ -158,9 +185,42 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue multicastTTL: 1, multicastLoop: true, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, + sndBufSizeMax: 32 * 1024, state: StateInitial, + uniqueID: s.UniqueID(), } + + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + e.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + e.rcvBufSizeMax = rs.Default + } + + return e +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + err := e.lastError + e.lastError = nil + return err +} + +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() } // Close puts the endpoint in a closed state and frees all resources @@ -171,8 +231,10 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } for _, mem := range e.multicastMemberships { @@ -203,14 +265,13 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return buffer.View{}, tcpip.ControlMessages{}, err + } + e.rcvMu.Lock() if e.rcvList.Empty() { @@ -232,7 +293,29 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + cm := tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + } + e.mu.RLock() + receiveTOS := e.receiveTOS + receiveTClass := e.receiveTClass + receiveIPPacketInfo := e.receiveIPPacketInfo + e.mu.RUnlock() + if receiveTOS { + cm.HasTOS = true + cm.TOS = p.tos + } + if receiveTClass { + cm.HasTClass = true + // Although TClass is an 8-bit value it's read in the CMsg as a uint32. + cm.TClass = uint32(p.tos) + } + if receiveIPPacketInfo { + cm.HasIPPacketInfo = true + cm.PacketInfo = p.packetInfo + } + return p.data.ToView(), cm, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -278,7 +361,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -286,20 +369,20 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr } if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicid == 0 { - nicid = e.multicastNICID + if nicID == 0 { + nicID = e.multicastNICID } - if localAddr == "" && nicid == 0 { + if localAddr == "" && nicID == 0 { localAddr = e.multicastAddr } } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { return stack.Route{}, 0, err } - return r, nicid, nil + return r, nicID, nil } // Write writes data to the endpoint's peer. This method does not block @@ -328,6 +411,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return 0, nil, err + } + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -356,58 +443,68 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } var route *stack.Route + var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) var dstPort uint16 if to == nil { route = &e.route dstPort = e.dstPort - - if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, given that it may need to - // change in Route.Resolve() call below. + resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { + // Promote lock to exclusive if using a shared route, given that it may + // need to change in Route.Resolve() call below. e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() // Recheck state after lock was re-acquired. if e.state != StateConnected { - return 0, nil, tcpip.ErrInvalidEndpointState + err = tcpip.ErrInvalidEndpointState + } + if err == nil && route.IsResolutionRequired() { + ch, err = route.Resolve(waker) + } + + e.mu.Unlock() + e.mu.RLock() + + // Recheck state after lock was re-acquired. + if e.state != StateConnected { + err = tcpip.ErrInvalidEndpointState } + return } } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } - if to.Addr == header.IPv4Broadcast && !e.broadcast { - return 0, nil, tcpip.ErrBroadcastDisabled - } - - netProto, err := e.checkV4Mapped(to, false) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - r, _, err := e.connectRoute(nicid, *to, netProto) + r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { return 0, nil, err } defer r.Release() route = &r - dstPort = to.Port + dstPort = dst.Port + resolve = route.Resolve + } + + if !e.broadcast && route.IsOutboundBroadcast() { + return 0, nil, tcpip.ErrBroadcastDisabled } if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { + if ch, err := resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { return 0, ch, tcpip.ErrNoLinkAddress } @@ -433,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -444,14 +541,54 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { - return nil -} +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v + e.mu.Unlock() + + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = v + e.mu.Unlock() + + case tcpip.NoChecksumOption: + e.mu.Lock() + e.noChecksum = v + e.mu.Unlock() + + case tcpip.ReceiveTOSOption: + e.mu.Lock() + e.receiveTOS = v + e.mu.Unlock() + + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrNotSupported + } + + e.mu.Lock() + e.receiveTClass = v + e.mu.Unlock() + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() + + case tcpip.ReusePortOption: + e.mu.Lock() + e.portFlags.LoadBalanced = v + e.mu.Unlock() -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -466,24 +603,94 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.v6only = v != 0 + e.v6only = v + } + + return nil +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) e.mu.Unlock() - case tcpip.MulticastTTLOption: + case tcpip.IPv4TOSOption: e.mu.Lock() - e.multicastTTL = uint8(v) + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) + } + + if v < rs.Min { + v = rs.Min + } + if v > rs.Max { + v = rs.Max + } + + e.mu.Lock() + e.rcvBufSizeMax = v e.mu.Unlock() + return nil + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) + } + if v < ss.Min { + v = ss.Min + } + if v > ss.Max { + v = ss.Max + } + + e.mu.Lock() + e.sndBufSizeMax = v + e.mu.Unlock() + return nil + } + + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.MulticastInterfaceOption: e.mu.Lock() defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa, false) + fa, netProto, err := e.checkV4MappedLocked(fa) if err != nil { return err } @@ -601,56 +808,124 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] - case tcpip.MulticastLoopOption: + case tcpip.BindToDeviceOption: + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice + } e.mu.Lock() - e.multicastLoop = bool(v) + e.bindToDevice = id e.mu.Unlock() - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() + case tcpip.SocketDetachFilterOption: + return nil + } + return nil +} - case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + return v, nil + + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + return v, nil + + case tcpip.NoChecksumOption: + e.mu.RLock() + v := e.noChecksum + e.mu.RUnlock() + return v, nil + + case tcpip.ReceiveTOSOption: + e.mu.RLock() + v := e.receiveTOS + e.mu.RUnlock() + return v, nil + + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrNotSupported } - for nicid, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicid - return nil - } + + e.mu.RLock() + v := e.receiveTClass + e.mu.RUnlock() + return v, nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil + + case tcpip.ReuseAddressOption: + e.mu.RLock() + v := e.portFlags.MostRecent + e.mu.RUnlock() + + return v, nil + + case tcpip.ReusePortOption: + e.mu.RLock() + v := e.portFlags.LoadBalanced + e.mu.RUnlock() + + return v, nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption } - return tcpip.ErrUnknownDevice - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v != 0 - e.mu.Unlock() + e.mu.RLock() + v := e.v6only + e.mu.RUnlock() - return nil + return v, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + switch opt { case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + + case tcpip.MulticastTTLOption: e.mu.Lock() - e.sendTOS = uint8(v) + v := int(e.multicastTTL) e.mu.Unlock() - return nil - } - return nil -} + return v, nil -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { - switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -663,7 +938,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { case tcpip.SendBufferSizeOption: e.mu.Lock() - v := e.sndBufSize + v := e.sndBufSizeMax e.mu.Unlock() return v, nil @@ -672,45 +947,23 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { v := e.rcvBufSizeMax e.rcvMu.Unlock() return v, nil - } - return -1, tcpip.ErrUnknownProtocolOption + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - return nil - - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.TTLOption: - e.mu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() - return nil - - case *tcpip.MulticastTTLOption: - e.mu.Lock() - *o = tcpip.MulticastTTLOption(e.multicastTTL) - e.mu.Unlock() - return nil - + return e.takeLastError() case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -718,87 +971,43 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.multicastAddr, } e.mu.Unlock() - return nil - - case *tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - - *o = tcpip.MulticastLoopOption(v) - return nil - - case *tcpip.ReusePortOption: - e.mu.RLock() - v := e.reusePort - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = tcpip.BindToDeviceOption("") - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast + *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.IPv4TOSOption: - e.mu.RLock() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } + return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { - // Allocate a buffer for the UDP header. - hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: data, + }) + pkt.Owner = owner - // Initialize the header. - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ SrcPort: localPort, DstPort: remotePort, Length: length, }) - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { xsum = header.Checksum(v, xsum) @@ -809,7 +1018,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: ProtocolNumber, + TTL: ttl, + TOS: tos, + }, pkt); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -819,36 +1032,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if len(addr.Addr) == 0 { - return netProto, nil - } - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - - // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.ID.LocalAddress) == 16 { - return 0, tcpip.ErrNetworkUnreachable - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + if err != nil { + return tcpip.FullAddress{}, 0, err } - - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -859,7 +1050,15 @@ func (e *endpoint) Disconnect() *tcpip.Error { if e.state != StateConnected { return nil } - id := stack.TransportEndpointID{} + var ( + id stack.TransportEndpointID + btd tcpip.NICID + ) + + // We change this value below and we need the old value to unregister + // the endpoint. + boundPortFlags := e.boundPortFlags + // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error @@ -867,21 +1066,24 @@ func (e *endpoint) Disconnect() *tcpip.Error { LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } e.state = StateBound + boundPortFlags = e.boundPortFlags } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + e.boundPortFlags = ports.Flags{} } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id + e.boundBindToDevice = btd e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -891,10 +1093,6 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr, false) - if err != nil { - return err - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -903,7 +1101,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC var localPort uint16 switch e.state { case StateInitial: @@ -913,16 +1111,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } - r, nicid, err := e.connectRoute(nicid, addr, netProto) + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err } @@ -950,20 +1153,23 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } - id, err = e.registerWithStack(nicid, netProtos, id) + oldPortFlags := e.boundPortFlags + + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id + e.boundBindToDevice = btd e.route = r.Clone() e.dstPort = addr.Port - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos e.state = StateConnected @@ -1018,20 +1224,22 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { - return id, err + return id, e.bindToDevice, err } id.LocalPort = port } + e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.boundPortFlags = ports.Flags{} } - return id, err + return id, e.bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -1041,7 +1249,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, true) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -1057,11 +1265,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } } - nicid := addr.NIC + nicID := addr.NIC if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { // A local unicast address was specified, verify that it's valid. - nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicid == 0 { + nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicID == 0 { return tcpip.ErrBadLocalAddress } } @@ -1070,13 +1278,14 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, err = e.registerWithStack(nicid, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id - e.RegisterNICID = nicid + e.boundBindToDevice = btd + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos // Mark endpoint as bound. @@ -1111,9 +1320,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() + addr := e.ID.LocalAddress + if e.state == StateConnected { + addr = e.route.LocalAddress + } + return tcpip.FullAddress{ NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, + Addr: addr, Port: e.ID.LocalPort, }, nil } @@ -1154,22 +1368,47 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.TransportHeader().View()) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.UDPMinimumSize) + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + + // Verify checksum unless RX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value means + // the transmitter omitted the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && + (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + if hdr.CalculateChecksum(xsum) != 0xffff { + // Checksum Error. + e.stack.Stats().UDP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() + return + } + } - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() @@ -1188,18 +1427,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &udpPacket{ + packet := &udpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) - e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + packet.data = pkt.Data + e.rcvList.PushBack(packet) + e.rcvBufSize += pkt.Data.Size() + + // Save any useful information from the network header to the packet. + switch r.NetProto { + case header.IPv4ProtocolNumber: + packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() + case header.IPv6ProtocolNumber: + packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS() + } - pkt.timestamp = e.stack.NowNanoseconds() + // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast + // address. packetInfo.LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.LocalAddress + packet.packetInfo.NIC = r.NICID() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() @@ -1210,7 +1463,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { + if typ == stack.ControlPortUnreachable { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state == StateConnected { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + e.lastError = tcpip.ErrConnectionRefused + } + } } // State implements tcpip.Endpoint.State. @@ -1234,6 +1498,13 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } + +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index b227e353b..851e6b635 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,6 +37,24 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = tcpip.StringToError(s) +} + // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). @@ -69,6 +87,9 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + e.stack = s for _, m := range e.multicastMemberships { @@ -109,7 +130,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index d399ec722..c67e0ba95 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -16,7 +16,6 @@ package udp import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, id: id, - vv: vv, + pkt: pkt, }) return true @@ -62,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - vv buffer.VectorisedView + pkt *stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -74,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() return nil, err } @@ -83,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort ep.RegisterNICID = r.route.NICID() + ep.boundPortFlags = ep.portFlags ep.state = StateConnected @@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.vv) + ep.HandlePacket(r.route, r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 5c3358a5e..63d4bed7c 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -32,9 +32,24 @@ import ( const ( // ProtocolNumber is the udp protocol number. ProtocolNumber = header.UDPProtocolNumber + + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4KiB bytes. + + // DefaultSendBufferSize is the default size of the send buffer for + // an endpoint. + DefaultSendBufferSize = 32 << 10 // 32KiB + + // DefaultReceiveBufferSize is the default size of the receive buffer + // for an endpoint. + DefaultReceiveBufferSize = 32 << 10 // 32KiB + + // MaxBufferSize is the largest size a receive/send buffer can grow to. + MaxBufferSize = 4 << 20 // 4MiB ) -type protocol struct{} +type protocol struct { +} // Number returns the udp protocol number. func (*protocol) Number() tcpip.TransportProtocolNumber { @@ -66,10 +81,9 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + hdr := header.UDP(pkt.TransportHeader().View()) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -116,28 +130,30 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } - // The buffers used by vv and netHeader may be used elsewhere - // in the system. For example, a raw or packet socket may use - // what UDP considers an unreachable destination. Thus we deep - // copy vv and netHeader to prevent multiple ownership and SR - // errors. - newNetHeader := make(buffer.View, len(netHeader)) - copy(newNetHeader, netHeader) - payload := buffer.NewVectorisedView(len(newNetHeader), []buffer.View{newNetHeader}) - payload.Append(vv.ToView().ToVectorisedView()) + // The buffers used by pkt may be used elsewhere in the system. + // For example, a raw or packet socket may use what UDP + // considers an unreachable destination. Thus we deep copy pkt + // to prevent multiple ownership and SR errors. + newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) + newHeader = append(newHeader, pkt.TransportHeader().View()...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4DstUnreachable) - pkt.SetCode(header.ICMPv4PortUnreachable) - pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -158,34 +174,50 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + payloadLen := len(network) + len(transport) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) + payload.Append(pkt.Data) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) - pkt.SetType(header.ICMPv6DstUnreachable) - pkt.SetCode(header.ICMPv6PortUnreachable) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) } return true } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + return ok +} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b724d788c..f87d99d5a 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "context" "fmt" "math/rand" "testing" @@ -56,6 +57,7 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast + testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -81,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -115,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -166,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -197,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -222,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -233,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -244,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -255,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -273,11 +293,16 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - - s := stack.New(stack.Options{ + return newDualTestContextWithOptions(t, mtu, stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, }) +} + +func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { + t.Helper() + + s := stack.New(options) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -285,15 +310,15 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { wep = sniffer.New(ep) } if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } s.SetRouteTable([]tcpip.Route{ @@ -335,12 +360,12 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) } } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) + if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) } } } @@ -351,30 +376,30 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - h := flow.header4Tuple(outgoing) - checkers := append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b - - case <-time.After(2 * time.Second): + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil } - return nil + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + h := flow.header4Tuple(outgoing) + checkers = append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) + return b } // injectPacket creates a packet of the given flow and with the given payload, @@ -384,24 +409,30 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) } else { - c.injectV6Packet(payload, &h, true /* valid */) + buf := c.buildV6Packet(payload, &h) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) } } -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV6Packet creates a V6 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, @@ -411,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: l, + Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. @@ -430,23 +455,22 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + return buf } -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV4Packet creates a V4 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, + TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -470,8 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + return buf } func newPayload() []byte { @@ -493,50 +516,46 @@ func TestBindToDeviceOption(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) + opts := stack.NICOptions{Name: "my_device"} + if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { + t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } @@ -545,8 +564,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { +// correctness including any additional checker functions provided. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { c.t.Helper() payload := newPayload() @@ -561,12 +580,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) + v, cm, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, _, err = c.ep.Read(&addr) + v, cm, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -592,22 +611,28 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr) } // Check the payload. if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + + // Run any checkers against the ControlMessages. + for _, f := range checkers { + f(c.t, cm) + } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// its correctness including any additional checker functions provided. +func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) } // testFailingRead sends a packet of the given test flow into the stack by @@ -625,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } } @@ -636,19 +661,19 @@ func TestBindReservedPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } addr, err := c.ep.GetLocalAddress() if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) + t.Fatalf("GetLocalAddress failed: %s", err) } // We can't bind the address reserved by the connected endpoint above. { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { @@ -659,7 +684,7 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint @@ -669,7 +694,7 @@ func TestBindReservedPort(t *testing.T) { } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() @@ -679,11 +704,11 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() } @@ -696,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -711,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { // Bind to v4 mapped wildcard. if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -726,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -741,13 +766,59 @@ func TestV6ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. testRead(c, unicastV6) } +// TestV4ReadSelfSource checks that packets coming from a local IP address are +// correctly dropped when handleLocal is true and not otherwise. +func TestV4ReadSelfSource(t *testing.T) { + for _, tt := range []struct { + name string + handleLocal bool + wantErr *tcpip.Error + wantInvalidSource uint64 + }{ + {"HandleLocal", false, nil, 0}, + {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, + } { + t.Run(tt.name, func(t *testing.T) { + c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + HandleLocal: tt.handleLocal, + }) + defer c.cleanup() + + c.createEndpointForFlow(unicastV4) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + h.srcAddr = h.dstAddr + + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { + t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) + } + + if _, _, err := c.ep.Read(nil); err != tt.wantErr { + t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr) + } + }) + } +} + func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -756,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -819,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -894,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... payload := buffer.View(newPayload()) n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { - c.t.Fatalf("Write failed: %v", err) + c.t.Fatalf("Write failed: %s", err) } if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) @@ -944,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } p := testDualWrite(c) @@ -961,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV6) @@ -982,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV4in6) @@ -1009,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Write to v6 address. @@ -1024,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV6) @@ -1038,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV4) @@ -1173,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testRead(c, unicastV4) @@ -1184,6 +1309,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { } } +func TestReadIPPacketInfo(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedLocalAddr tcpip.Address + expectedDestAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedLocalAddr: stackAddr, + expectedDestAddr: stackAddr, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastAddr, + expectedDestAddr: multicastAddr, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: broadcastAddr, + expectedDestAddr: broadcastAddr, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedLocalAddr: stackV6Addr, + expectedDestAddr: stackV6Addr, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastV6Addr, + expectedDestAddr: multicastV6Addr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err) + } + } + + if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { + t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) + } + + testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: 1, + LocalAddr: test.expectedLocalAddr, + DestinationAddr: test.expectedDestAddr, + })) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1198,6 +1422,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } +func TestNoChecksum(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Disable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + // This option is effective on IPv4 only. + testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) + + // Enable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) + }) + } +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1207,8 +1455,8 @@ func TestTTL(t *testing.T) { c.createEndpointForFlow(flow) const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { + c.t.Fatalf("SetSockOptInt failed: %s", err) } var wantTTL uint8 @@ -1221,10 +1469,10 @@ func TestTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } + ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + })) wantTTL = ep.DefaultTTL() ep.Close() } @@ -1244,8 +1492,8 @@ func TestSetTTL(t *testing.T) { c.createEndpointForFlow(flow) - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } var p stack.NetworkProtocol @@ -1254,10 +1502,10 @@ func TestSetTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } + ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + })) ep.Close() testWrite(c, flow, checker.TTL(wantTTL)) @@ -1267,7 +1515,7 @@ func TestSetTTL(t *testing.T) { } } -func TestTOSV4(t *testing.T) { +func TestSetTOS(t *testing.T) { for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1275,26 +1523,27 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + const tos = testTOS + v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1302,7 +1551,7 @@ func TestTOSV4(t *testing.T) { } } -func TestTOSV6(t *testing.T) { +func TestSetTClass(t *testing.T) { for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1310,33 +1559,96 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + const tClass = testTOS + v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - c.t.Errorf("SetSockOpt failed: %s", err) + if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { + c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tClass { + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) } - testWrite(c, flow, checker.TOS(tos, 0)) + // The header getter for TClass is called TOS, so use that checker. + testWrite(c, flow, checker.TOS(tClass, 0)) }) } } +func TestReceiveTosTClass(t *testing.T) { + testCases := []struct { + name string + getReceiveOption tcpip.SockOptBool + tests []testFlow + }{ + {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, + {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + } + for _, testCase := range testCases { + for _, flow := range testCase.tests { + t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + option := testCase.getReceiveOption + name := testCase.name + + // Verify that setting and reading the option works. + v, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) + } + + want := true + if err := c.ep.SetSockOptBool(option, want); err != nil { + c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) + } + + got, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) + } + + if got != want { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) + } + + // Verify that the correct received TOS or TClass is handed through as + // ancillary data to the ControlMessages struct. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + switch option { + case tcpip.ReceiveTClassOption: + testRead(c, flow, checker.ReceiveTClass(testTOS)) + case tcpip.ReceiveTOSOption: + testRead(c, flow, checker.ReceiveTOS(testTOS)) + default: + t.Fatalf("unknown test variant: %s", name) + } + }) + } + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1375,12 +1687,12 @@ func TestMulticastInterfaceOption(t *testing.T) { Port: stackPort, } if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } } if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOpt failed: %s", err) } // Verify multicast interface addr and NIC were set correctly. @@ -1388,7 +1700,7 @@ func TestMulticastInterfaceOption(t *testing.T) { ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} var ifoptGot tcpip.MulticastInterfaceOption if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) + c.t.Fatalf("GetSockOpt failed: %s", err) } if ifoptGot != ifoptWant { c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) @@ -1431,48 +1743,51 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) } }) } @@ -1505,54 +1820,57 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) } }) } } // TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. +// global and endpoint stats are incremented. func TestIncrementMalformedPacketsReceived(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1560,20 +1878,271 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } payload := newPayload() - c.t.Helper() h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) + buf := c.buildV6Packet(payload, &h) - var want uint64 = 1 + // Invalidate the UDP header length field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetLength(u.Length() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) } if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } +} + +// TestShortHeader verifies that when a packet with a too-short UDP header is +// received, the malformed received global stat gets incremented. +func TestShortHeader(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + h := unicastV6.header4Tuple(incoming) + + // Allocate a buffer for an IPv6 and too-short UDP header. + const udpSize = header.UDPMinimumSize - 1 + buf := buffer.NewView(header.IPv6MinimumSize + udpSize) + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, + }) + + // Initialize the UDP header. + udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, + Length: header.UDPMinimumSize, + }) + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + // Copy all but the last byte of the UDP header into the packet. + copy(buf[header.IPv6MinimumSize:], udpHdr) + + // Inject packet. + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + } +} + +// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + + // Invalidate the UDP header checksum field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV4 verifies if the checksum value is zero, global and +// endpoint states are *not* incremented (UDP checksum is optional on IPv4). +func TestChecksumZeroV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + const want = 0 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV6 verifies if the checksum value is zero, global and +// endpoint states are incremented (UDP checksum is *not* optional on IPv6). +func TestChecksumZeroV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) } } @@ -1587,15 +2156,15 @@ func TestShutdownRead(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingRead(c, unicastV6, true /* expectReadError */) @@ -1618,11 +2187,11 @@ func TestShutdownWrite(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) @@ -1664,3 +2233,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) } } + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const nicID1 = 1 + + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + requiresBroadcastOpt bool + }{ + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + requiresBroadcastOpt: true, + }, + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + requiresBroadcastOpt: false, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + requiresBroadcastOpt: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) + } + defer ep.Close() + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + to := tcpip.FullAddress{ + Addr: test.remoteAddr, + Port: 80, + } + opts := tcpip.WriteOptions{To: &to} + expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + if !test.requiresBroadcastOpt { + expectedErrWithoutBcastOpt = nil + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != nil { + t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + }) + } +} |